mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
Compare commits
65 Commits
feat/umi
...
feat/dummy
| Author | SHA1 | Date | |
|---|---|---|---|
| c868777752 | |||
| 8847e75c55 | |||
| 8429d2ccfa | |||
| 6794ca2ba8 | |||
| 98c2152f08 | |||
| f92999aeb9 | |||
| 5659c77988 | |||
| fd88a3acda | |||
| 6deabe4b71 | |||
| 2f3525c4a2 | |||
| d04061def7 | |||
| 07ee578c78 | |||
| 636e2264c3 | |||
| 5a4c168d92 | |||
| 047f89cc2a | |||
| 4d64733846 | |||
| 0c3ed6ca7a | |||
| 44322fa726 | |||
| e041634bee | |||
| 6b6c0623cc | |||
| 6db3afca6f | |||
| 433ccc9603 | |||
| 9e92337f24 | |||
| 99eea2ae03 | |||
| ac33f20e51 | |||
| ab0a9c3d7a | |||
| 9616c44024 | |||
| 60b432b0f1 | |||
| 513e6c0046 | |||
| 60362b9c7c | |||
| 5915649eac | |||
| 675880392d | |||
| d0123c4178 | |||
| e86afc883e | |||
| d10b7787eb | |||
| ac1816ee9c | |||
| 25fb16ea7a | |||
| 7baf909e32 | |||
| 79ffe316e4 | |||
| 68b2142bd2 | |||
| a42fb4d0e2 | |||
| 83f1de035e | |||
| e09a6a90e1 | |||
| 10cc9dd961 | |||
| 41b8d4b7c6 | |||
| 7939fc3ddf | |||
| 11b35dfa11 | |||
| b27570039c | |||
| 55c4cc1b27 | |||
| 3fb3edde3f | |||
| 43bf1fb763 | |||
| c7a26f5070 | |||
| aaa308b158 | |||
| 84df6cd13d | |||
| 26db4b64d8 | |||
| 2204a45020 | |||
| b6df884d08 | |||
| bb23dafad1 | |||
| c409ed2d1d | |||
| d20ef2e46e | |||
| 05189361b6 | |||
| 896779003c | |||
| b55bc62ef0 | |||
| 08ff689a1e | |||
| 0acdde4ae2 |
@@ -0,0 +1,263 @@
|
||||
# RTC Profiling Guide
|
||||
|
||||
This guide explains how to profile RTC (Real-Time Chunking) performance to identify bottlenecks and understand why RTC might be slower than expected.
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 1. Profile with Real Robot (Profiled Version)
|
||||
|
||||
Use `eval_with_real_robot_profiled.py` to profile actual robot execution:
|
||||
|
||||
```bash
|
||||
# With RTC enabled
|
||||
uv run examples/rtc/eval_with_real_robot_profiled.py \
|
||||
--policy.path=helper2424/pi05_check_rtc \
|
||||
--policy.device=mps \
|
||||
--rtc.enabled=true \
|
||||
--rtc.execution_horizon=20 \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
||||
--robot.id=so100_follower \
|
||||
--robot.cameras="{ gripper: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}}" \
|
||||
--task="Move green small object into the purple platform" \
|
||||
--duration=30
|
||||
|
||||
# Without RTC for comparison
|
||||
uv run examples/rtc/eval_with_real_robot_profiled.py \
|
||||
--policy.path=helper2424/pi05_check_rtc \
|
||||
--policy.device=mps \
|
||||
--rtc.enabled=false \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
||||
--robot.id=so100_follower \
|
||||
--robot.cameras="{ gripper: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}}" \
|
||||
--task="Move green small object into the purple platform" \
|
||||
--duration=30
|
||||
```
|
||||
|
||||
**Output**: At the end of execution, you'll see a detailed breakdown of timing for each component:
|
||||
- `get_actions.policy_inference` - Time spent in policy inference
|
||||
- `get_actions.preprocessing` - Time spent preprocessing observations
|
||||
- `get_actions.postprocessing` - Time spent postprocessing actions
|
||||
- `get_actions.action_queue_merge` - Time spent merging actions with RTC
|
||||
- `robot.get_observation` - Time to get observations from robot
|
||||
- `robot.send_action` - Time to send actions to robot
|
||||
- And more...
|
||||
|
||||
### 2. Profile Without Robot (Comparison Script)
|
||||
|
||||
Use `profile_rtc_comparison.py` to profile just the policy inference without needing a robot:
|
||||
|
||||
```bash
|
||||
uv run examples/rtc/profile_rtc_comparison.py \
|
||||
--policy_path=helper2424/pi05_check_rtc \
|
||||
--device=mps \
|
||||
--num_iterations=50 \
|
||||
--execution_horizon=20
|
||||
```
|
||||
|
||||
**Output**: Side-by-side comparison of performance with and without RTC, including:
|
||||
- Mean/min/max inference times
|
||||
- Throughput (iterations per second)
|
||||
- Verdict on whether RTC is faster or slower
|
||||
|
||||
### 3. Enable Detailed Method-Level Profiling
|
||||
|
||||
For even more granular profiling, add the `--enable_detailed_profiling` flag:
|
||||
|
||||
```bash
|
||||
uv run examples/rtc/profile_rtc_comparison.py \
|
||||
--policy_path=helper2424/pi05_check_rtc \
|
||||
--device=mps \
|
||||
--num_iterations=50 \
|
||||
--execution_horizon=20 \
|
||||
--enable_detailed_profiling
|
||||
```
|
||||
|
||||
This will show timing for individual methods within the policy.
|
||||
|
||||
## Understanding the Output
|
||||
|
||||
### Key Metrics to Look At
|
||||
|
||||
1. **get_actions.policy_inference** - This should be the largest component
|
||||
- If RTC is enabled, this includes the RTC guidance overhead
|
||||
- Compare this with/without RTC to see the overhead
|
||||
|
||||
2. **get_actions.preprocessing** - Image preprocessing and normalization
|
||||
- Should be relatively fast
|
||||
- If slow, consider optimizing image processing
|
||||
|
||||
3. **get_actions.postprocessing** - Action denormalization
|
||||
- Should be minimal
|
||||
- If slow, check postprocessor implementation
|
||||
|
||||
4. **get_actions.action_queue_merge** - RTC-specific merging logic
|
||||
- Only present when RTC is enabled
|
||||
- If this is taking significant time, the RTC algorithm may need optimization
|
||||
|
||||
5. **robot.get_observation** - Robot communication overhead
|
||||
- If slow, check camera/sensor latency
|
||||
- Consider reducing image resolution
|
||||
|
||||
6. **robot.send_action** - Action execution overhead
|
||||
- Should be very fast
|
||||
- If slow, check robot communication
|
||||
|
||||
### Expected Performance
|
||||
|
||||
For a typical Pi0 policy on Apple Silicon (MPS):
|
||||
- **Without RTC**: ~100-200ms per inference
|
||||
- **With RTC**: Should be similar or slightly faster due to action reuse
|
||||
- **Preprocessing**: ~5-20ms depending on number of cameras
|
||||
- **Postprocessing**: ~1-5ms
|
||||
|
||||
If RTC is significantly slower, likely causes:
|
||||
1. **RTC overhead exceeds benefits** - The guidance computation is expensive
|
||||
2. **Execution horizon too small** - Not reusing enough actions to amortize overhead
|
||||
3. **No compilation** - Try with `--use_torch_compile`
|
||||
4. **Large prev_actions buffer** - Copying/processing previous actions is slow
|
||||
|
||||
## Profiling Your Own Code
|
||||
|
||||
### Using the Profiling Decorator
|
||||
|
||||
Add profiling to your own methods:
|
||||
|
||||
```python
|
||||
from lerobot.utils.profiling import profile_method, enable_profiling, print_profiling_summary
|
||||
|
||||
# Enable profiling
|
||||
enable_profiling()
|
||||
|
||||
# Decorate methods you want to profile
|
||||
@profile_method
|
||||
def my_slow_function(x):
|
||||
# ... your code ...
|
||||
return result
|
||||
|
||||
# At end of execution
|
||||
print_profiling_summary()
|
||||
```
|
||||
|
||||
### Using Profile Context Manager
|
||||
|
||||
For profiling specific code blocks:
|
||||
|
||||
```python
|
||||
from lerobot.utils.profiling import profile_section, enable_profiling
|
||||
|
||||
enable_profiling()
|
||||
|
||||
with profile_section("data_loading"):
|
||||
data = load_data()
|
||||
|
||||
with profile_section("model_inference"):
|
||||
output = model(data)
|
||||
```
|
||||
|
||||
### Adding Profiling to Policy Methods
|
||||
|
||||
To profile specific parts of the Pi0 policy, you can add decorators:
|
||||
|
||||
```python
|
||||
# In src/lerobot/policies/pi0/modeling_pi0.py
|
||||
from lerobot.utils.profiling import profile_method, profile_section
|
||||
|
||||
class Pi0Policy:
|
||||
@profile_method
|
||||
def predict_action_chunk(self, obs, inference_delay=0, prev_chunk_left_over=None):
|
||||
# ... existing code ...
|
||||
pass
|
||||
|
||||
def _generate_actions_with_rtc(self, ...):
|
||||
with profile_section("rtc.guidance_computation"):
|
||||
# ... guidance code ...
|
||||
pass
|
||||
|
||||
with profile_section("rtc.action_merging"):
|
||||
# ... merging code ...
|
||||
pass
|
||||
```
|
||||
|
||||
## Analyzing Results
|
||||
|
||||
### Comparison Checklist
|
||||
|
||||
When comparing RTC vs non-RTC performance, check:
|
||||
|
||||
- [ ] Is `policy_inference` time higher with RTC?
|
||||
- [ ] Is `action_queue_merge` taking significant time?
|
||||
- [ ] Are you running enough iterations to amortize warmup?
|
||||
- [ ] Is torch.compile enabled for fair comparison?
|
||||
- [ ] Is the execution horizon large enough? (should be >= 10-20)
|
||||
- [ ] Are you testing on the same hardware/device?
|
||||
|
||||
### Common Bottlenecks
|
||||
|
||||
1. **Image preprocessing dominates**
|
||||
- Solution: Reduce image resolution, use fewer cameras, or optimize preprocessing
|
||||
|
||||
2. **Action queue operations are slow**
|
||||
- Solution: Review queue implementation, consider using ring buffer
|
||||
|
||||
3. **RTC guidance is expensive**
|
||||
- Solution: Reduce guidance weight, simplify guidance computation, use torch.compile
|
||||
|
||||
4. **Robot communication is slow**
|
||||
- Solution: Increase baud rate, reduce action frequency, optimize protocol
|
||||
|
||||
5. **Memory allocation overhead**
|
||||
- Solution: Pre-allocate buffers, reuse tensors, avoid unnecessary copies
|
||||
|
||||
## Advanced: Adding Custom Metrics
|
||||
|
||||
You can add custom timing metrics to the profiled script:
|
||||
|
||||
```python
|
||||
from lerobot.utils.profiling import record_timing
|
||||
|
||||
start = time.perf_counter()
|
||||
# ... your code ...
|
||||
duration = time.perf_counter() - start
|
||||
record_timing("my_custom_metric", duration)
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Profiling shows RTC is slower by >50%
|
||||
|
||||
1. Check if torch.compile is enabled: `--use_torch_compile`
|
||||
2. Increase execution horizon: `--rtc.execution_horizon=30`
|
||||
3. Verify inference_delay is calculated correctly
|
||||
4. Profile with `--enable_detailed_profiling` to find exact bottleneck
|
||||
|
||||
### Profiling output is empty
|
||||
|
||||
1. Make sure profiling is enabled with `enable_profiling()`
|
||||
2. Verify you're running enough iterations (at least 10)
|
||||
3. Check that code is actually executing (not short-circuited)
|
||||
|
||||
### Inconsistent results between runs
|
||||
|
||||
1. Run more iterations: `--num_iterations=100`
|
||||
2. Increase warmup iterations
|
||||
3. Check for thermal throttling on device
|
||||
4. Ensure no other processes competing for resources
|
||||
|
||||
## Next Steps
|
||||
|
||||
1. Run both profiling scripts (with/without robot)
|
||||
2. Compare timing breakdowns
|
||||
3. Identify the largest bottleneck
|
||||
4. Focus optimization efforts on that component
|
||||
5. Re-run profiling to verify improvements
|
||||
|
||||
## Questions?
|
||||
|
||||
If profiling reveals unexpected bottlenecks or you need help interpreting results, please share:
|
||||
- The full profiling output
|
||||
- Your configuration (RTC enabled/disabled, execution horizon, etc.)
|
||||
- Hardware specs (device type, memory, etc.)
|
||||
- Policy type and size
|
||||
|
||||
@@ -0,0 +1,208 @@
|
||||
# RTC Profiling - Quick Start
|
||||
|
||||
Quick reference for profiling Pi0 with RTC to identify performance bottlenecks.
|
||||
|
||||
## 🚀 Quick Commands
|
||||
|
||||
### 1. Profile with Real Robot
|
||||
|
||||
```bash
|
||||
# With RTC enabled (profiled version)
|
||||
uv run examples/rtc/eval_with_real_robot_profiled.py \
|
||||
--policy.path=helper2424/pi05_check_rtc \
|
||||
--policy.device=mps \
|
||||
--rtc.enabled=true \
|
||||
--rtc.execution_horizon=20 \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
||||
--robot.cameras="{ gripper: {type: opencv, index_or_path: 0}, front: {type: opencv, index_or_path: 1}}" \
|
||||
--task="Pick up object" \
|
||||
--duration=30
|
||||
```
|
||||
|
||||
### 2. Compare RTC vs No-RTC (No Robot Needed)
|
||||
|
||||
```bash
|
||||
uv run examples/rtc/profile_rtc_comparison.py \
|
||||
--policy_path=helper2424/pi05_check_rtc \
|
||||
--device=mps \
|
||||
--num_iterations=50 \
|
||||
--execution_horizon=20
|
||||
```
|
||||
|
||||
### 3. Detailed RTC Method Profiling
|
||||
|
||||
```bash
|
||||
uv run examples/rtc/profile_pi0_rtc_detailed.py \
|
||||
--policy_path=helper2424/pi05_check_rtc \
|
||||
--device=mps \
|
||||
--num_iterations=20 \
|
||||
--execution_horizon=20 \
|
||||
--enable_rtc_profiling
|
||||
```
|
||||
|
||||
## 📊 What Each Tool Does
|
||||
|
||||
| Tool | Purpose | Needs Robot? |
|
||||
|------|---------|--------------|
|
||||
| `eval_with_real_robot_profiled.py` | Profile actual robot execution with RTC | ✅ Yes |
|
||||
| `profile_rtc_comparison.py` | Compare RTC vs no-RTC side-by-side | ❌ No |
|
||||
| `profile_pi0_rtc_detailed.py` | Deep dive into RTC internals | ❌ No |
|
||||
|
||||
## 🔍 Key Metrics to Watch
|
||||
|
||||
### Overall Performance
|
||||
- **iteration.policy_inference** - Total policy inference time
|
||||
- **iteration.preprocessing** - Image preprocessing time
|
||||
- **iteration.postprocessing** - Action denormalization time
|
||||
|
||||
### RTC-Specific (with `--enable_rtc_profiling`)
|
||||
- **rtc.denoise_step.base_denoising** - Time without RTC overhead
|
||||
- **rtc.denoise_step.autograd_correction** - Gradient computation time
|
||||
- **rtc.denoise_step.guidance_computation** - Total RTC guidance overhead
|
||||
|
||||
### Robot Communication
|
||||
- **robot.get_observation** - Time to get robot state
|
||||
- **robot.send_action** - Time to send action command
|
||||
|
||||
## 🎯 Quick Diagnosis
|
||||
|
||||
### RTC is slower than expected?
|
||||
|
||||
1. **Check if torch.compile is enabled**
|
||||
```bash
|
||||
# Add this flag
|
||||
--use_torch_compile
|
||||
```
|
||||
|
||||
2. **Try larger execution horizon**
|
||||
```bash
|
||||
# Increase to amortize RTC overhead
|
||||
--rtc.execution_horizon=30
|
||||
```
|
||||
|
||||
3. **Profile to find bottleneck**
|
||||
```bash
|
||||
uv run examples/rtc/profile_pi0_rtc_detailed.py \
|
||||
--policy_path=helper2424/pi05_check_rtc \
|
||||
--device=mps \
|
||||
--enable_rtc_profiling
|
||||
```
|
||||
|
||||
### Preprocessing is slow?
|
||||
|
||||
- Reduce image resolution in robot config
|
||||
- Use fewer cameras
|
||||
- Check camera FPS settings
|
||||
|
||||
### Policy inference is slow?
|
||||
|
||||
- Enable torch.compile
|
||||
- Check device (MPS vs CUDA vs CPU)
|
||||
- Try smaller model if available
|
||||
|
||||
## 📈 Expected Performance
|
||||
|
||||
### Typical timings on Apple Silicon (MPS):
|
||||
|
||||
| Component | Time (ms) | Notes |
|
||||
|-----------|-----------|-------|
|
||||
| Policy inference | 100-200 | Depends on model size |
|
||||
| Preprocessing | 5-20 | Depends on #cameras |
|
||||
| Postprocessing | 1-5 | Usually fast |
|
||||
| RTC overhead | 10-50 | Should be < 50% of base |
|
||||
|
||||
### When RTC helps:
|
||||
- ✅ Execution horizon ≥ 10
|
||||
- ✅ Inference time > action execution rate
|
||||
- ✅ Using torch.compile
|
||||
- ✅ Proper inference_delay calculation
|
||||
|
||||
### When RTC might not help:
|
||||
- ❌ Very fast inference already
|
||||
- ❌ Small execution horizon (< 5)
|
||||
- ❌ No compilation (interpreted mode)
|
||||
- ❌ Inference delay not accounted for
|
||||
|
||||
## 🛠️ Adding Profiling to Your Code
|
||||
|
||||
### Quick snippet:
|
||||
|
||||
```python
|
||||
from lerobot.utils.profiling import enable_profiling, print_profiling_summary, profile_section
|
||||
|
||||
# Enable at start
|
||||
enable_profiling()
|
||||
|
||||
# Profile sections
|
||||
with profile_section("my_operation"):
|
||||
# ... your code ...
|
||||
pass
|
||||
|
||||
# Print at end
|
||||
print_profiling_summary()
|
||||
```
|
||||
|
||||
### Profile specific methods:
|
||||
|
||||
```python
|
||||
from lerobot.utils.profiling import profile_method
|
||||
|
||||
@profile_method
|
||||
def my_slow_function():
|
||||
# ... your code ...
|
||||
pass
|
||||
```
|
||||
|
||||
## 📝 Example Output
|
||||
|
||||
```
|
||||
PROFILING SUMMARY
|
||||
================================================================================
|
||||
Function Count Mean (ms)
|
||||
--------------------------------------------------------------------------------
|
||||
iteration.policy_inference 20 150.23
|
||||
iteration.preprocessing 20 12.45
|
||||
rtc.denoise_step.guidance_computation 200 15.67
|
||||
rtc.denoise_step.autograd_correction 200 8.23
|
||||
rtc.denoise_step.base_denoising 200 120.45
|
||||
================================================================================
|
||||
```
|
||||
|
||||
## 🚨 Common Issues
|
||||
|
||||
### "No profiling data available"
|
||||
- Did you call `enable_profiling()`?
|
||||
- Running enough iterations?
|
||||
|
||||
### Inconsistent results
|
||||
- Increase `--num_iterations`
|
||||
- Check for thermal throttling
|
||||
- Close other applications
|
||||
|
||||
### Can't find bottleneck
|
||||
- Enable `--enable_rtc_profiling` for detailed breakdown
|
||||
- Check both preprocessing and inference
|
||||
- Compare with and without RTC
|
||||
|
||||
## 📖 More Details
|
||||
|
||||
See `PROFILING_GUIDE.md` for comprehensive documentation.
|
||||
|
||||
## 🤔 Still Slow?
|
||||
|
||||
1. Run comparison: `profile_rtc_comparison.py`
|
||||
2. Run detailed profiling: `profile_pi0_rtc_detailed.py --enable_rtc_profiling`
|
||||
3. Share output for help (include device, model, settings)
|
||||
|
||||
## ✅ Quick Checklist
|
||||
|
||||
Before asking for help, verify:
|
||||
|
||||
- [ ] Ran comparison script (with/without RTC)
|
||||
- [ ] Tried torch.compile
|
||||
- [ ] Tested different execution horizons (10, 20, 30)
|
||||
- [ ] Profiled with detailed RTC profiling
|
||||
- [ ] Checked preprocessing vs inference split
|
||||
- [ ] Verified hardware (device type, thermal state)
|
||||
|
||||
@@ -0,0 +1,352 @@
|
||||
# RTC Profiling Toolkit
|
||||
|
||||
Complete toolkit for profiling Pi0 with RTC to identify performance bottlenecks.
|
||||
|
||||
## 📦 What's Included
|
||||
|
||||
### Scripts
|
||||
|
||||
1. **`eval_with_real_robot_profiled.py`**
|
||||
- Profiled version of the real robot eval script
|
||||
- Adds timing measurements throughout execution
|
||||
- Works with actual robot hardware
|
||||
- Same usage as original but with profiling output
|
||||
|
||||
2. **`profile_rtc_comparison.py`**
|
||||
- Side-by-side comparison of RTC vs no-RTC
|
||||
- No robot needed (uses mock observations)
|
||||
- Shows clear verdict on whether RTC is helping
|
||||
- Great for quick performance checks
|
||||
|
||||
3. **`profile_pi0_rtc_detailed.py`**
|
||||
- Most detailed profiling available
|
||||
- Can enable RTC method-level profiling
|
||||
- Provides insights and recommendations
|
||||
- Perfect for deep-dive investigations
|
||||
|
||||
4. **`add_rtc_profiling.py`**
|
||||
- Monkey-patching utility for RTC internals
|
||||
- Profiles individual RTC operations
|
||||
- Can be applied without modifying source
|
||||
- Shows exactly where RTC spends time
|
||||
|
||||
### Utilities
|
||||
|
||||
5. **`src/lerobot/utils/profiling.py`**
|
||||
- Core profiling utilities
|
||||
- Decorators for method profiling
|
||||
- Context managers for code blocks
|
||||
- Statistics collection and reporting
|
||||
|
||||
### Documentation
|
||||
|
||||
6. **`PROFILING_GUIDE.md`** - Comprehensive guide
|
||||
7. **`PROFILING_QUICK_START.md`** - Quick reference
|
||||
|
||||
## 🚀 Quick Start
|
||||
|
||||
### Step 1: Compare Performance
|
||||
|
||||
Run this first to see if RTC is actually slower:
|
||||
|
||||
```bash
|
||||
uv run examples/rtc/profile_rtc_comparison.py \
|
||||
--policy_path=helper2424/pi05_check_rtc \
|
||||
--device=mps \
|
||||
--num_iterations=50 \
|
||||
--execution_horizon=20
|
||||
```
|
||||
|
||||
**Expected output:**
|
||||
```
|
||||
COMPARISON SUMMARY
|
||||
================================================================================
|
||||
Metric Without RTC With RTC Difference
|
||||
--------------------------------------------------------------------------------
|
||||
Mean time (ms) 150.23 165.45 +15.22
|
||||
Throughput (iter/s) 6.66 6.05 -0.61
|
||||
================================================================================
|
||||
VERDICT
|
||||
✗ RTC is SLOWER by 10.1%
|
||||
Mean time increased by 15.22 ms
|
||||
|
||||
Possible reasons:
|
||||
- RTC overhead exceeds benefits at current execution horizon
|
||||
- No torch.compile enabled
|
||||
```
|
||||
|
||||
### Step 2: Identify Bottleneck
|
||||
|
||||
If RTC is slower, find out why:
|
||||
|
||||
```bash
|
||||
uv run examples/rtc/profile_pi0_rtc_detailed.py \
|
||||
--policy_path=helper2424/pi05_check_rtc \
|
||||
--device=mps \
|
||||
--num_iterations=20 \
|
||||
--execution_horizon=20 \
|
||||
--enable_rtc_profiling
|
||||
```
|
||||
|
||||
**Expected output:**
|
||||
```
|
||||
PROFILING SUMMARY
|
||||
================================================================================
|
||||
Function Count Mean (ms) Total (s)
|
||||
------------------------------------------------------------------------------------
|
||||
iteration.policy_inference 20 150.23 3.00
|
||||
rtc.denoise_step.guidance_computation 200 15.67 3.13
|
||||
rtc.denoise_step.autograd_correction 200 8.23 1.65
|
||||
iteration.preprocessing 20 12.45 0.25
|
||||
================================================================================
|
||||
|
||||
KEY INSIGHTS
|
||||
================================================================================
|
||||
Time breakdown:
|
||||
Policy inference: 150.23 ms (87.2%)
|
||||
Preprocessing: 12.45 ms (7.2%)
|
||||
Postprocessing: 2.10 ms (1.2%)
|
||||
|
||||
RTC breakdown:
|
||||
Base denoising: 120.45 ms
|
||||
Guidance compute: 15.67 ms
|
||||
Autograd correct: 8.23 ms
|
||||
RTC overhead: 23.90 ms (19.8% of base)
|
||||
|
||||
Recommendations:
|
||||
⚠ RTC autograd overhead is significant
|
||||
→ This is expected, but consider increasing execution_horizon
|
||||
→ Try torch.compile if not already enabled
|
||||
💡 torch.compile not enabled
|
||||
→ Try --use_torch_compile for potential speedup
|
||||
================================================================================
|
||||
```
|
||||
|
||||
### Step 3: Try Optimizations
|
||||
|
||||
Based on recommendations:
|
||||
|
||||
```bash
|
||||
# Try with torch.compile
|
||||
uv run examples/rtc/profile_rtc_comparison.py \
|
||||
--policy_path=helper2424/pi05_check_rtc \
|
||||
--device=mps \
|
||||
--num_iterations=50 \
|
||||
--execution_horizon=20 \
|
||||
--use_torch_compile
|
||||
|
||||
# Try larger execution horizon
|
||||
uv run examples/rtc/profile_rtc_comparison.py \
|
||||
--policy_path=helper2424/pi05_check_rtc \
|
||||
--device=mps \
|
||||
--num_iterations=50 \
|
||||
--execution_horizon=30
|
||||
```
|
||||
|
||||
### Step 4: Profile Real Robot (Optional)
|
||||
|
||||
Test with actual hardware:
|
||||
|
||||
```bash
|
||||
uv run examples/rtc/eval_with_real_robot_profiled.py \
|
||||
--policy.path=helper2424/pi05_check_rtc \
|
||||
--policy.device=mps \
|
||||
--rtc.enabled=true \
|
||||
--rtc.execution_horizon=20 \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
||||
--robot.cameras="{...}" \
|
||||
--task="Pick up object" \
|
||||
--duration=30
|
||||
```
|
||||
|
||||
## 🎯 Common Scenarios
|
||||
|
||||
### "RTC is 2x slower!"
|
||||
|
||||
This usually means:
|
||||
- RTC overhead is high but not getting benefits
|
||||
- Need to enable torch.compile
|
||||
- Execution horizon too small
|
||||
- Inference delay not calculated correctly
|
||||
|
||||
**Try:**
|
||||
1. `--use_torch_compile`
|
||||
2. Increase `--execution_horizon` to 30+
|
||||
3. Check inference_delay calculation
|
||||
|
||||
### "RTC is only slightly slower"
|
||||
|
||||
This is expected! RTC overhead is about 10-30% typically.
|
||||
The benefit comes during **execution**, not single inference:
|
||||
- Actions are reused across chunks
|
||||
- Overall system latency is reduced
|
||||
- Robot gets smoother actions
|
||||
|
||||
### "Want to optimize specific part"
|
||||
|
||||
Use the profiling utilities:
|
||||
|
||||
```python
|
||||
from lerobot.utils.profiling import enable_profiling, profile_section, print_profiling_summary
|
||||
|
||||
enable_profiling()
|
||||
|
||||
with profile_section("my_custom_operation"):
|
||||
# Your code here
|
||||
pass
|
||||
|
||||
print_profiling_summary()
|
||||
```
|
||||
|
||||
## 📊 Understanding Results
|
||||
|
||||
### Key Metrics
|
||||
|
||||
**Policy Inference Time**
|
||||
- Time for forward pass through model
|
||||
- Should be largest component (70-90%)
|
||||
- Includes RTC guidance if enabled
|
||||
|
||||
**Preprocessing Time**
|
||||
- Image normalization, resizing
|
||||
- Should be < 20% of total
|
||||
- If high: reduce image resolution
|
||||
|
||||
**RTC Guidance Overhead**
|
||||
- Extra time for RTC guidance computation
|
||||
- Typically 10-30% of base inference
|
||||
- If > 50%: RTC may not be beneficial at current settings
|
||||
|
||||
**Autograd Correction**
|
||||
- Time computing gradients for RTC
|
||||
- Usually 5-15% of base inference
|
||||
- Can be reduced with torch.compile
|
||||
|
||||
### Expected Ranges (Apple Silicon MPS)
|
||||
|
||||
| Metric | Good | Acceptable | Poor |
|
||||
|--------|------|------------|------|
|
||||
| Policy inference | 100-150ms | 150-250ms | >250ms |
|
||||
| Preprocessing | <20ms | 20-50ms | >50ms |
|
||||
| RTC overhead | 10-30% | 30-50% | >50% |
|
||||
|
||||
## 🔧 Optimization Guide
|
||||
|
||||
### If RTC overhead is too high:
|
||||
|
||||
1. **Enable compilation:**
|
||||
```bash
|
||||
--use_torch_compile
|
||||
```
|
||||
Expected improvement: 20-40% faster
|
||||
|
||||
2. **Increase execution horizon:**
|
||||
```bash
|
||||
--execution_horizon=30 # or higher
|
||||
```
|
||||
Amortizes RTC cost over more actions
|
||||
|
||||
3. **Check guidance weight:**
|
||||
```python
|
||||
# In config
|
||||
rtc.max_guidance_weight=1.0 # try 0.5 for less overhead
|
||||
```
|
||||
|
||||
### If preprocessing is slow:
|
||||
|
||||
1. **Reduce image resolution:**
|
||||
```python
|
||||
# In robot config
|
||||
cameras={
|
||||
"gripper": {"width": 320, "height": 240} # instead of 640x480
|
||||
}
|
||||
```
|
||||
|
||||
2. **Use fewer cameras:**
|
||||
- Profile which cameras are essential
|
||||
- Remove unnecessary views
|
||||
|
||||
### If inference is generally slow:
|
||||
|
||||
1. Use torch.compile (if not already)
|
||||
2. Check device is correct (MPS vs CUDA)
|
||||
3. Verify model is in eval mode
|
||||
4. Check for unnecessary gradient tracking
|
||||
|
||||
## 🐛 Troubleshooting
|
||||
|
||||
### Empty profiling output
|
||||
```python
|
||||
# Make sure to enable profiling!
|
||||
from lerobot.utils.profiling import enable_profiling
|
||||
enable_profiling()
|
||||
```
|
||||
|
||||
### Inconsistent timings
|
||||
- Run more iterations (50-100)
|
||||
- Check thermal throttling
|
||||
- Close background apps
|
||||
- Use `--warmup_iterations=10`
|
||||
|
||||
### Can't find bottleneck
|
||||
1. Start with `profile_rtc_comparison.py`
|
||||
2. Then run `profile_pi0_rtc_detailed.py --enable_rtc_profiling`
|
||||
3. Compare with/without RTC
|
||||
4. Check each component separately
|
||||
|
||||
## 📖 Full Documentation
|
||||
|
||||
- **`PROFILING_GUIDE.md`** - Complete reference with examples
|
||||
- **`PROFILING_QUICK_START.md`** - Quick commands and tips
|
||||
|
||||
## 🤝 Getting Help
|
||||
|
||||
If you're still experiencing issues:
|
||||
|
||||
1. Run comparison script and save output
|
||||
2. Run detailed profiling and save output
|
||||
3. Include:
|
||||
- Policy path
|
||||
- Device type
|
||||
- RTC settings (execution_horizon, etc.)
|
||||
- Hardware specs
|
||||
- Full profiling output
|
||||
|
||||
## 🎓 Learning More
|
||||
|
||||
### Profiling your own code:
|
||||
|
||||
```python
|
||||
from lerobot.utils.profiling import profile_method, enable_profiling
|
||||
|
||||
enable_profiling()
|
||||
|
||||
@profile_method
|
||||
def my_function():
|
||||
# Automatically profiled
|
||||
pass
|
||||
```
|
||||
|
||||
### RTC internals:
|
||||
|
||||
```python
|
||||
from examples.rtc.add_rtc_profiling import monkey_patch_rtc_profiling
|
||||
|
||||
enable_profiling()
|
||||
monkey_patch_rtc_profiling()
|
||||
|
||||
# Now RTC methods are profiled
|
||||
policy.predict_action_chunk(...)
|
||||
```
|
||||
|
||||
## ✨ Next Steps
|
||||
|
||||
1. Run `profile_rtc_comparison.py` to establish baseline
|
||||
2. Use `profile_pi0_rtc_detailed.py` to find bottlenecks
|
||||
3. Apply optimizations (torch.compile, larger horizon)
|
||||
4. Re-run comparison to verify improvements
|
||||
5. Test with real robot using profiled version
|
||||
|
||||
Happy profiling! 🚀
|
||||
|
||||
@@ -0,0 +1,251 @@
|
||||
# Real-Time Chunking (RTC) Examples
|
||||
|
||||
This directory contains examples and evaluation scripts for Real-Time Chunking (RTC), a technique for improving action chunking policies in real-time robot control.
|
||||
|
||||
## Overview
|
||||
|
||||
Real-Time Chunking addresses the challenge of maintaining consistency and reactivity when using action chunking policies with non-negligible inference latency. It uses a guidance technique during diffusion sampling to blend new action predictions with previously planned actions.
|
||||
|
||||
**Key Benefits:**
|
||||
|
||||
- Maintains consistency between consecutive action chunks
|
||||
- Reduces jitter and improves smoothness
|
||||
- Adapts to inference delays dynamically
|
||||
|
||||
**Reference:** [Physical Intelligence - Real-Time Chunking](https://www.physicalintelligence.company/download/real_time_chunking.pdf)
|
||||
|
||||
## Scripts
|
||||
|
||||
### 1. `eval_dataset.py`
|
||||
|
||||
Offline evaluation on dataset samples with detailed visualization and validation.
|
||||
|
||||
**Features:**
|
||||
|
||||
- Compare RTC vs non-RTC predictions on two random dataset samples
|
||||
- Validate RTC behavior (delay region, blend region, post-horizon region)
|
||||
- Generate debug visualizations:
|
||||
- Denoising step comparisons (x_t, v_t, x1_t, corrections)
|
||||
- Final action predictions comparison
|
||||
- Support for torch.compile() optimization
|
||||
- Memory-efficient sequential policy loading for large models
|
||||
|
||||
**Usage:**
|
||||
|
||||
```bash
|
||||
# Basic usage with SmolVLA policy
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=helper2424/check_rtc \
|
||||
--rtc.execution_horizon=8 \
|
||||
--device=mps \
|
||||
--rtc.max_guidance_weight=10.0 \
|
||||
--seed=10
|
||||
|
||||
# With Pi0.5 policy on CUDA
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=lerobot/pi05_libero_finetuned \
|
||||
--dataset.repo_id=HuggingFaceVLA/libero \
|
||||
--rtc.execution_horizon=8 \
|
||||
--device=cuda
|
||||
|
||||
# With Pi0 policy
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=lerobot/pi0_libero_finetuned \
|
||||
--dataset.repo_id=HuggingFaceVLA/libero \
|
||||
--rtc.execution_horizon=8 \
|
||||
--device=cuda
|
||||
|
||||
# With torch.compile for faster inference
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=helper2424/check_rtc \
|
||||
--rtc.execution_horizon=8 \
|
||||
--device=cuda \
|
||||
--use_torch_compile=true \
|
||||
--torch_compile_mode=max-autotune
|
||||
|
||||
# Enable CUDA graphs (advanced - may cause tensor aliasing errors)
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=helper2424/check_rtc \
|
||||
--use_torch_compile=true \
|
||||
--torch_compile_backend=inductor \
|
||||
--torch_compile_mode=max-autotune \
|
||||
--torch_compile_disable_cudagraphs=false
|
||||
```
|
||||
|
||||
**Key Parameters:**
|
||||
|
||||
- `--policy.path`: Path to pretrained policy
|
||||
- `--dataset.repo_id`: Dataset to evaluate on
|
||||
- `--rtc.execution_horizon`: Number of steps to maintain consistency (default: 20)
|
||||
- `--rtc.max_guidance_weight`: Maximum guidance weight (default: 10.0)
|
||||
- `--rtc.prefix_attention_schedule`: Schedule type (ZEROS, ONES, LINEAR, EXP)
|
||||
- `--inference_delay`: Inference delay for RTC (default: 4)
|
||||
- `--seed`: Random seed for reproducibility (default: 42)
|
||||
- `--output_dir`: Directory to save visualizations (default: rtc_debug_output)
|
||||
- `--device`: Device to use (cuda, cpu, mps, auto)
|
||||
- `--use_torch_compile`: Enable torch.compile() for faster inference
|
||||
|
||||
**Output:**
|
||||
|
||||
The script generates several visualization files in `rtc_debug_output/`:
|
||||
|
||||
- `denoising_xt_comparison.png` - Noisy state evolution during denoising
|
||||
- `denoising_vt_comparison.png` - Velocity predictions during denoising
|
||||
- `denoising_x1t_comparison.png` - Predicted final states during denoising
|
||||
- `denoising_correction_comparison.png` - RTC guidance corrections applied
|
||||
- `final_actions_comparison.png` - Final action predictions (prev_chunk, no_rtc, rtc)
|
||||
|
||||
The script also validates RTC behavior and reports:
|
||||
|
||||
- ✅ Delay region [0:inference_delay]: RTC = prev_chunk
|
||||
- ✅ Blend region [inference_delay:execution_horizon]: prev_chunk ≤ RTC ≤ no_rtc
|
||||
- ✅ Post-horizon [execution_horizon:]: RTC = no_rtc
|
||||
|
||||
### 2. `eval_with_real_robot.py`
|
||||
|
||||
Real-time evaluation on physical robots or simulation environments.
|
||||
|
||||
**Features:**
|
||||
|
||||
- Run policy with RTC on real robot or simulation
|
||||
- Multi-threaded action execution and inference
|
||||
- Action queue management with proper timing
|
||||
- Latency tracking and adaptive inference delay
|
||||
- Support for both robots and gym environments
|
||||
- Support for torch.compile() optimization
|
||||
|
||||
**Usage:**
|
||||
|
||||
```bash
|
||||
# With real robot
|
||||
uv run python examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=lerobot/smolvla_base \
|
||||
--robot.type=so100 \
|
||||
--task="pick up the cup" \
|
||||
--duration=30.0
|
||||
|
||||
# With simulation environment
|
||||
uv run python examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=lerobot/smolvla_base \
|
||||
--env.type=pusht \
|
||||
--duration=60.0
|
||||
|
||||
# With policy compilation (CUDA only, not MPS)
|
||||
uv run python examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=lerobot/smolvla_base \
|
||||
--robot.type=so100 \
|
||||
--use_torch_compile=true \
|
||||
--torch_compile_mode=max-autotune
|
||||
```
|
||||
|
||||
**Key Parameters:**
|
||||
|
||||
- `--policy.path`: Path to pretrained policy
|
||||
- `--robot.type` or `--env.type`: Robot or environment to use
|
||||
- `--task`: Task description (for VLA models)
|
||||
- `--rtc.execution_horizon`: Number of steps to maintain consistency (default: 10)
|
||||
- `--rtc.max_guidance_weight`: Maximum guidance weight (default: 1.0)
|
||||
- `--rtc.prefix_attention_schedule`: Schedule type (ZEROS, ONES, LINEAR, EXP)
|
||||
- `--duration`: How long to run (seconds, default: 30.0)
|
||||
- `--fps`: Action execution frequency (Hz, default: 10.0)
|
||||
- `--action_queue_size_to_get_new_actions`: Queue size threshold to request new actions (default: 30)
|
||||
- `--device`: Device to use (cuda, cpu, mps, auto)
|
||||
- `--use_torch_compile`: Enable torch.compile() for faster inference
|
||||
|
||||
## Understanding RTC Parameters
|
||||
|
||||
### `execution_horizon`
|
||||
|
||||
Number of timesteps from previous chunk to maintain consistency with. Higher values mean more consistency but potentially less reactivity.
|
||||
|
||||
**Typical values:** 8-12 steps for dataset evaluation, 10 steps for real-time execution
|
||||
|
||||
### `max_guidance_weight`
|
||||
|
||||
Upper bound on guidance strength. Higher values give stronger consistency but may over-constrain new predictions.
|
||||
|
||||
**Typical values:**
|
||||
|
||||
- Dataset evaluation: 10.0-100.0 (can be higher for analysis)
|
||||
- Real-time execution: 1.0-10.0 (more conservative)
|
||||
|
||||
### `prefix_attention_schedule`
|
||||
|
||||
How to weight consistency across the overlap region:
|
||||
|
||||
- `ZEROS`: Binary (full weight up to inference_delay, then zero)
|
||||
- `ONES`: Full weight across entire execution_horizon
|
||||
- `LINEAR`: Linear decay from inference_delay to execution_horizon
|
||||
- `EXP`: Exponential decay (recommended)
|
||||
|
||||
**Recommended:** `EXP`
|
||||
|
||||
### `inference_delay`
|
||||
|
||||
Number of timesteps from the prefix to use for guidance. Typically calculated dynamically based on inference latency in real-time execution, but fixed for dataset evaluation.
|
||||
|
||||
**Typical values:** 3-5 steps for dataset evaluation
|
||||
|
||||
### `action_queue_size_to_get_new_actions` (real-time only)
|
||||
|
||||
Threshold for requesting new action chunks. Should be higher than `inference_delay + execution_horizon` to ensure smooth operation.
|
||||
|
||||
**Typical values:** 20-30 steps
|
||||
|
||||
## Validation Rules (Dataset Evaluation)
|
||||
|
||||
The dataset evaluation script validates that RTC behavior matches expectations:
|
||||
|
||||
1. **Delay Region [0:inference_delay]**: RTC actions should equal previous chunk
|
||||
- Ensures consistency during the inference delay period
|
||||
|
||||
2. **Blend Region [inference_delay:execution_horizon]**: RTC should be between prev_chunk and no_rtc
|
||||
- Smooth transition from previous plan to new predictions
|
||||
|
||||
3. **Post-Horizon [execution_horizon:]**: RTC should equal no_rtc
|
||||
- Full adoption of new predictions after execution horizon
|
||||
|
||||
## Tips
|
||||
|
||||
1. **Start with dataset evaluation** (`eval_dataset.py`) to understand RTC behavior and tune parameters before running on robot
|
||||
2. **Use visualizations** to debug unexpected behavior - check denoising steps and final actions
|
||||
3. **Tune execution_horizon** based on your inference latency and action frequency
|
||||
4. **Monitor validation output** - failures indicate potential implementation issues or misconfigured parameters
|
||||
5. **Compare different schedules** - EXP usually works best but LINEAR can be more interpretable
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Validation fails in delay region
|
||||
|
||||
- Check that `prev_chunk_left_over` is properly passed to the policy
|
||||
- Verify RTC guidance is being applied during denoising
|
||||
- Look at denoising visualizations to see where guidance diverges
|
||||
|
||||
### Validation fails in post-horizon region
|
||||
|
||||
- RTC and no_rtc use different noise - verify same noise is being used for comparison
|
||||
- Check that weights are correctly zeroed out after execution horizon
|
||||
- Review prefix_attention_schedule visualization
|
||||
|
||||
### Poor performance on real robot
|
||||
|
||||
- Increase `action_queue_size_to_get_new_actions` if you see warnings
|
||||
- Reduce `max_guidance_weight` if robot is too conservative
|
||||
- Try different `prefix_attention_schedule` values
|
||||
- Enable torch.compile() for faster inference (CUDA only)
|
||||
|
||||
### Memory issues with large models
|
||||
|
||||
- The dataset evaluation script loads policies sequentially to minimize memory
|
||||
- For real-time execution, only one policy is loaded
|
||||
- Use smaller batch sizes if needed
|
||||
|
||||
## Related Documentation
|
||||
|
||||
- [RTC Implementation](../../src/lerobot/policies/rtc/modeling_rtc.py)
|
||||
- [RTC Configuration](../../src/lerobot/policies/rtc/configuration_rtc.py)
|
||||
- [Action Queue](../../src/lerobot/policies/rtc/action_queue.py)
|
||||
- [Physical Intelligence Paper](https://www.physicalintelligence.company/download/real_time_chunking.pdf)
|
||||
@@ -0,0 +1,202 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
"""
|
||||
Script to add profiling instrumentation to RTCProcessor.
|
||||
|
||||
This script shows which methods to profile in the RTC code to identify bottlenecks.
|
||||
You can either:
|
||||
1. Apply these changes directly to modeling_rtc.py
|
||||
2. Use monkey patching to add profiling without modifying source
|
||||
3. Use as reference for manual instrumentation
|
||||
|
||||
Usage:
|
||||
# Option 1: Monkey patch (no source changes)
|
||||
python examples/rtc/add_rtc_profiling.py
|
||||
|
||||
# Option 2: Apply changes to source
|
||||
# Copy the profiled methods below into src/lerobot/policies/rtc/modeling_rtc.py
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
from lerobot.utils.profiling import ProfileContext, enable_profiling, is_profiling_enabled
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def profile_denoise_step(self, x_t, prev_chunk_left_over, inference_delay, time, original_denoise_step_partial, execution_horizon=None) -> Tensor:
|
||||
"""Profiled version of denoise_step."""
|
||||
|
||||
if not is_profiling_enabled():
|
||||
# Call original implementation if profiling disabled
|
||||
return self._original_denoise_step(x_t, prev_chunk_left_over, inference_delay, time, original_denoise_step_partial, execution_horizon)
|
||||
|
||||
with ProfileContext("rtc.denoise_step.total"):
|
||||
# In the original implementation, the time goes from 0 to 1 and
|
||||
# In our implementation, the time goes from 1 to 0
|
||||
# So we need to invert the time
|
||||
tau = 1 - time
|
||||
|
||||
if prev_chunk_left_over is None:
|
||||
# First step, no guidance - return v_t
|
||||
with ProfileContext("rtc.denoise_step.base_denoising"):
|
||||
v_t = original_denoise_step_partial(x_t)
|
||||
return v_t
|
||||
|
||||
with ProfileContext("rtc.denoise_step.setup"):
|
||||
x_t = x_t.clone().detach()
|
||||
|
||||
squeezed = False
|
||||
if len(x_t.shape) < 3:
|
||||
x_t = x_t.unsqueeze(0)
|
||||
squeezed = True
|
||||
|
||||
if len(prev_chunk_left_over.shape) < 3:
|
||||
prev_chunk_left_over = prev_chunk_left_over.unsqueeze(0)
|
||||
|
||||
if execution_horizon is None:
|
||||
execution_horizon = self.rtc_config.execution_horizon
|
||||
|
||||
if execution_horizon > prev_chunk_left_over.shape[1]:
|
||||
execution_horizon = prev_chunk_left_over.shape[1]
|
||||
|
||||
batch_size = x_t.shape[0]
|
||||
action_chunk_size = x_t.shape[1]
|
||||
action_dim = x_t.shape[2]
|
||||
|
||||
# Padding
|
||||
with ProfileContext("rtc.denoise_step.padding"):
|
||||
if prev_chunk_left_over.shape[1] < action_chunk_size or prev_chunk_left_over.shape[2] < action_dim:
|
||||
padded = torch.zeros(batch_size, action_chunk_size, action_dim).to(x_t.device)
|
||||
padded[:, : prev_chunk_left_over.shape[1], : prev_chunk_left_over.shape[2]] = prev_chunk_left_over
|
||||
prev_chunk_left_over = padded
|
||||
|
||||
# Get prefix weights
|
||||
with ProfileContext("rtc.denoise_step.get_prefix_weights"):
|
||||
weights = (
|
||||
self.get_prefix_weights(inference_delay, execution_horizon, action_chunk_size)
|
||||
.to(x_t.device)
|
||||
.unsqueeze(0)
|
||||
.unsqueeze(-1)
|
||||
)
|
||||
|
||||
# Main RTC guidance computation
|
||||
with ProfileContext("rtc.denoise_step.guidance_computation"):
|
||||
with torch.enable_grad():
|
||||
# Base denoising
|
||||
with ProfileContext("rtc.denoise_step.base_denoising"):
|
||||
v_t = original_denoise_step_partial(x_t)
|
||||
|
||||
x_t.requires_grad_(True)
|
||||
|
||||
# Compute x1_t
|
||||
with ProfileContext("rtc.denoise_step.compute_x1_t"):
|
||||
x1_t = x_t - time * v_t
|
||||
|
||||
# Compute error
|
||||
with ProfileContext("rtc.denoise_step.compute_error"):
|
||||
err = (prev_chunk_left_over - x1_t) * weights
|
||||
grad_outputs = err.clone().detach()
|
||||
|
||||
# Compute correction via autograd
|
||||
with ProfileContext("rtc.denoise_step.autograd_correction"):
|
||||
correction = torch.autograd.grad(x1_t, x_t, grad_outputs, retain_graph=False)[0]
|
||||
|
||||
# Compute guidance weight
|
||||
with ProfileContext("rtc.denoise_step.compute_guidance_weight"):
|
||||
max_guidance_weight = torch.as_tensor(self.rtc_config.max_guidance_weight)
|
||||
tau_tensor = torch.as_tensor(tau)
|
||||
squared_one_minus_tau = (1 - tau_tensor) ** 2
|
||||
inv_r2 = (squared_one_minus_tau + tau_tensor**2) / (squared_one_minus_tau)
|
||||
c = torch.nan_to_num((1 - tau_tensor) / tau_tensor, posinf=max_guidance_weight)
|
||||
guidance_weight = torch.nan_to_num(c * inv_r2, posinf=max_guidance_weight)
|
||||
guidance_weight = torch.minimum(guidance_weight, max_guidance_weight)
|
||||
|
||||
# Apply guidance
|
||||
with ProfileContext("rtc.denoise_step.apply_guidance"):
|
||||
result = v_t - guidance_weight * correction
|
||||
|
||||
# Cleanup
|
||||
with ProfileContext("rtc.denoise_step.cleanup"):
|
||||
if squeezed:
|
||||
result = result.squeeze(0)
|
||||
correction = correction.squeeze(0)
|
||||
x1_t = x1_t.squeeze(0)
|
||||
err = err.squeeze(0)
|
||||
|
||||
self.track(
|
||||
time=time,
|
||||
x1_t=x1_t,
|
||||
correction=correction,
|
||||
err=err,
|
||||
weights=weights,
|
||||
guidance_weight=guidance_weight,
|
||||
inference_delay=inference_delay,
|
||||
execution_horizon=execution_horizon,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def monkey_patch_rtc_profiling():
|
||||
"""Apply profiling to RTCProcessor via monkey patching.
|
||||
|
||||
This modifies the RTCProcessor class at runtime to add profiling
|
||||
without changing source files.
|
||||
"""
|
||||
logger.info("Applying RTC profiling monkey patch...")
|
||||
|
||||
# Save original method
|
||||
RTCProcessor._original_denoise_step = RTCProcessor.denoise_step
|
||||
|
||||
# Replace with profiled version
|
||||
RTCProcessor.denoise_step = profile_denoise_step
|
||||
|
||||
logger.info("✓ RTC profiling enabled")
|
||||
|
||||
|
||||
def print_usage():
|
||||
"""Print usage instructions."""
|
||||
print("\n" + "="*80)
|
||||
print("RTC PROFILING INSTRUMENTATION")
|
||||
print("="*80)
|
||||
print("\nThis script provides profiling for RTCProcessor methods.")
|
||||
print("\nOption 1: Monkey Patch (Recommended)")
|
||||
print("-" * 40)
|
||||
print("Add to your script:")
|
||||
print("""
|
||||
from lerobot.utils.profiling import enable_profiling, print_profiling_summary
|
||||
from examples.rtc.add_rtc_profiling import monkey_patch_rtc_profiling
|
||||
|
||||
# Enable profiling
|
||||
enable_profiling()
|
||||
monkey_patch_rtc_profiling()
|
||||
|
||||
# ... run your code ...
|
||||
|
||||
# Print results
|
||||
print_profiling_summary()
|
||||
""")
|
||||
|
||||
print("\nOption 2: Manual Source Modification")
|
||||
print("-" * 40)
|
||||
print("1. Copy profile_denoise_step() from this file")
|
||||
print("2. Replace denoise_step() in src/lerobot/policies/rtc/modeling_rtc.py")
|
||||
print("3. Add profiling imports at top of file")
|
||||
|
||||
print("\nKey Metrics to Watch:")
|
||||
print("-" * 40)
|
||||
print("- rtc.denoise_step.base_denoising - Time for base policy inference")
|
||||
print("- rtc.denoise_step.autograd_correction - Time computing gradients")
|
||||
print("- rtc.denoise_step.guidance_computation - Total guidance overhead")
|
||||
print("- rtc.denoise_step.get_prefix_weights - Time computing weights")
|
||||
print("="*80 + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print_usage()
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,549 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Demo script showing how to use Real-Time Chunking (RTC) with action chunking policies on real robots.
|
||||
|
||||
This script demonstrates:
|
||||
1. Creating a robot and policy (SmolVLA, Pi0, etc.) with RTC
|
||||
2. Consuming actions from the policy while the robot executes
|
||||
3. Periodically requesting new action chunks in the background using threads
|
||||
4. Managing action buffers and timing for real-time operation
|
||||
|
||||
For simulation environments, see eval_with_simulation.py
|
||||
|
||||
Usage:
|
||||
# Run RTC with Real robot with RTC
|
||||
uv run examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--policy.device=mps \
|
||||
--rtc.enabled=true \
|
||||
--rtc.execution_horizon=20 \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
||||
--robot.id=so100_follower \
|
||||
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--task="Move green small object into the purple platform" \
|
||||
--duration=120
|
||||
|
||||
# Run RTC with Real robot without RTC
|
||||
uv run examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--policy.device=mps \
|
||||
--rtc.enabled=false \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
||||
--robot.id=so100_follower \
|
||||
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--task="Move green small object into the purple platform" \
|
||||
--duration=120
|
||||
|
||||
# Run RTC with Real robot with pi0.5 policy
|
||||
uv run examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=helper2424/pi05_check_rtc \
|
||||
--policy.device=mps \
|
||||
--rtc.enabled=true \
|
||||
--rtc.execution_horizon=20 \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
||||
--robot.id=so100_follower \
|
||||
--robot.cameras="{ gripper: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}}" \
|
||||
--task="Move green small object into the purple platform" \
|
||||
--duration=120
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from dataclasses import dataclass, field
|
||||
from threading import Event, Lock, Thread
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
|
||||
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
||||
from lerobot.policies.rtc.action_queue import ActionQueue
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.latency_tracker import LatencyTracker
|
||||
from lerobot.processor.factory import (
|
||||
make_default_robot_action_processor,
|
||||
make_default_robot_observation_processor,
|
||||
)
|
||||
from lerobot.rl.process import ProcessSignalHandler
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
koch_follower,
|
||||
so100_follower,
|
||||
so101_follower,
|
||||
)
|
||||
from lerobot.robots.utils import make_robot_from_config
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
from lerobot.utils.hub import HubMixin
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RobotWrapper:
|
||||
def __init__(self, robot: Robot):
|
||||
self.robot = robot
|
||||
self.lock = Lock()
|
||||
|
||||
def get_observation(self) -> dict[str, Tensor]:
|
||||
with self.lock:
|
||||
return self.robot.get_observation()
|
||||
|
||||
def send_action(self, action: Tensor):
|
||||
with self.lock:
|
||||
self.robot.send_action(action)
|
||||
|
||||
def observation_features(self) -> list[str]:
|
||||
with self.lock:
|
||||
return self.robot.observation_features
|
||||
|
||||
def action_features(self) -> list[str]:
|
||||
with self.lock:
|
||||
return self.robot.action_features
|
||||
|
||||
|
||||
@dataclass
|
||||
class RTCDemoConfig(HubMixin):
|
||||
"""Configuration for RTC demo with action chunking policies and real robots."""
|
||||
|
||||
# Policy configuration
|
||||
policy: PreTrainedConfig | None = None
|
||||
|
||||
# Robot configuration
|
||||
robot: RobotConfig | None = None
|
||||
|
||||
# RTC configuration
|
||||
rtc: RTCConfig = field(
|
||||
default_factory=lambda: RTCConfig(
|
||||
execution_horizon=10,
|
||||
max_guidance_weight=1.0,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||
)
|
||||
)
|
||||
|
||||
# Demo parameters
|
||||
duration: float = 30.0 # Duration to run the demo (seconds)
|
||||
fps: float = 10.0 # Action execution frequency (Hz)
|
||||
|
||||
# Compute device
|
||||
device: str | None = None # Device to run on (cuda, cpu, auto)
|
||||
|
||||
# Get new actions horizon. The amount of executed steps after which will be requested new actions.
|
||||
# It should be higher than inference delay + execution horizon.
|
||||
action_queue_size_to_get_new_actions: int = 30
|
||||
|
||||
# Task to execute
|
||||
task: str = field(default="", metadata={"help": "Task to execute"})
|
||||
|
||||
# Torch compile configuration
|
||||
use_torch_compile: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use torch.compile for faster inference (PyTorch 2.0+)"},
|
||||
)
|
||||
|
||||
torch_compile_backend: str = field(
|
||||
default="inductor",
|
||||
metadata={"help": "Backend for torch.compile (inductor, aot_eager, cudagraphs)"},
|
||||
)
|
||||
|
||||
torch_compile_mode: str = field(
|
||||
default="default",
|
||||
metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"},
|
||||
)
|
||||
|
||||
torch_compile_disable_cudagraphs: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "Disable CUDA graphs in torch.compile. Required due to in-place tensor "
|
||||
"operations in denoising loop (x_t += dt * v_t) which cause tensor aliasing issues."
|
||||
},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
else:
|
||||
raise ValueError("Policy path is required")
|
||||
|
||||
# Validate that robot configuration is provided
|
||||
if self.robot is None:
|
||||
raise ValueError("Robot configuration must be provided")
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
|
||||
return ["policy"]
|
||||
|
||||
|
||||
def is_image_key(k: str) -> bool:
|
||||
return k.startswith(OBS_IMAGES)
|
||||
|
||||
|
||||
def get_actions(
|
||||
policy,
|
||||
robot: RobotWrapper,
|
||||
robot_observation_processor,
|
||||
action_queue: ActionQueue,
|
||||
shutdown_event: Event,
|
||||
cfg: RTCDemoConfig,
|
||||
):
|
||||
"""Thread function to request action chunks from the policy.
|
||||
|
||||
Args:
|
||||
policy: The policy instance (SmolVLA, Pi0, etc.)
|
||||
robot: The robot instance for getting observations
|
||||
robot_observation_processor: Processor for raw robot observations
|
||||
action_queue: Queue to put new action chunks
|
||||
shutdown_event: Event to signal shutdown
|
||||
cfg: Demo configuration
|
||||
"""
|
||||
try:
|
||||
logger.info("[GET_ACTIONS] Starting get actions thread")
|
||||
|
||||
latency_tracker = LatencyTracker() # Track latency of action chunks
|
||||
fps = cfg.fps
|
||||
time_per_chunk = 1.0 / fps
|
||||
|
||||
dataset_features = hw_to_dataset_features(robot.observation_features(), "observation")
|
||||
policy_device = policy.config.device
|
||||
|
||||
# Load preprocessor and postprocessor from pretrained files
|
||||
# The stats are embedded in the processor .safetensors files
|
||||
logger.info(f"[GET_ACTIONS] Loading preprocessor/postprocessor from {cfg.policy.pretrained_path}")
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=None, # Will load from pretrained processor files
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": cfg.policy.device},
|
||||
},
|
||||
)
|
||||
|
||||
logger.info("[GET_ACTIONS] Preprocessor/postprocessor loaded successfully with embedded stats")
|
||||
|
||||
get_actions_threshold = cfg.action_queue_size_to_get_new_actions
|
||||
|
||||
if not cfg.rtc.enabled:
|
||||
get_actions_threshold = 0
|
||||
|
||||
while not shutdown_event.is_set():
|
||||
if action_queue.qsize() <= get_actions_threshold:
|
||||
current_time = time.perf_counter()
|
||||
action_index_before_inference = action_queue.get_action_index()
|
||||
prev_actions = action_queue.get_left_over()
|
||||
|
||||
inference_latency = latency_tracker.max()
|
||||
inference_delay = math.ceil(inference_latency / time_per_chunk)
|
||||
|
||||
obs = robot.get_observation()
|
||||
|
||||
# Apply robot observation processor
|
||||
obs_processed = robot_observation_processor(obs)
|
||||
|
||||
obs_with_policy_features = build_dataset_frame(
|
||||
dataset_features, obs_processed, prefix="observation"
|
||||
)
|
||||
|
||||
for name in obs_with_policy_features:
|
||||
obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name])
|
||||
if "image" in name:
|
||||
obs_with_policy_features[name] = (
|
||||
obs_with_policy_features[name].type(torch.float32) / 255
|
||||
)
|
||||
obs_with_policy_features[name] = (
|
||||
obs_with_policy_features[name].permute(2, 0, 1).contiguous()
|
||||
)
|
||||
obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0)
|
||||
obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device)
|
||||
|
||||
obs_with_policy_features["task"] = [cfg.task] # Task should be a list, not a string!
|
||||
obs_with_policy_features["robot_type"] = (
|
||||
robot.robot.name if hasattr(robot.robot, "name") else ""
|
||||
)
|
||||
|
||||
preproceseded_obs = preprocessor(obs_with_policy_features)
|
||||
|
||||
# Generate actions WITH RTC
|
||||
actions = policy.predict_action_chunk(
|
||||
preproceseded_obs,
|
||||
inference_delay=inference_delay,
|
||||
prev_chunk_left_over=prev_actions,
|
||||
)
|
||||
|
||||
# Store original actions (before postprocessing) for RTC
|
||||
original_actions = actions.squeeze(0).clone()
|
||||
|
||||
postprocessed_actions = postprocessor(actions)
|
||||
|
||||
postprocessed_actions = postprocessed_actions.squeeze(0)
|
||||
|
||||
new_latency = time.perf_counter() - current_time
|
||||
new_delay = math.ceil(new_latency / time_per_chunk)
|
||||
latency_tracker.add(new_latency)
|
||||
|
||||
if cfg.action_queue_size_to_get_new_actions < cfg.rtc.execution_horizon + new_delay:
|
||||
logger.warning(
|
||||
"[GET_ACTIONS] cfg.action_queue_size_to_get_new_actions Too small, It should be higher than inference delay + execution horizon."
|
||||
)
|
||||
|
||||
action_queue.merge(
|
||||
original_actions, postprocessed_actions, new_delay, action_index_before_inference
|
||||
)
|
||||
else:
|
||||
# Small sleep to prevent busy waiting
|
||||
time.sleep(0.1)
|
||||
|
||||
logger.info("[GET_ACTIONS] get actions thread shutting down")
|
||||
except Exception as e:
|
||||
logger.error(f"[GET_ACTIONS] Fatal exception in get_actions thread: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def actor_control(
|
||||
robot: RobotWrapper,
|
||||
robot_action_processor,
|
||||
action_queue: ActionQueue,
|
||||
shutdown_event: Event,
|
||||
cfg: RTCDemoConfig,
|
||||
):
|
||||
"""Thread function to execute actions on the robot.
|
||||
|
||||
Args:
|
||||
robot: The robot instance
|
||||
action_queue: Queue to get actions from
|
||||
shutdown_event: Event to signal shutdown
|
||||
cfg: Demo configuration
|
||||
"""
|
||||
try:
|
||||
logger.info("[ACTOR] Starting actor thread")
|
||||
|
||||
action_count = 0
|
||||
action_interval = 1.0 / cfg.fps
|
||||
|
||||
while not shutdown_event.is_set():
|
||||
start_time = time.perf_counter()
|
||||
|
||||
# Try to get an action from the queue with timeout
|
||||
action = action_queue.get()
|
||||
|
||||
if action is not None:
|
||||
action = action.cpu()
|
||||
action_dict = {key: action[i].item() for i, key in enumerate(robot.action_features())}
|
||||
action_processed = robot_action_processor((action_dict, None))
|
||||
robot.send_action(action_processed)
|
||||
|
||||
action_count += 1
|
||||
|
||||
dt_s = time.perf_counter() - start_time
|
||||
time.sleep(max(0, (action_interval - dt_s) - 0.001))
|
||||
|
||||
logger.info(f"[ACTOR] Actor thread shutting down. Total actions executed: {action_count}")
|
||||
except Exception as e:
|
||||
logger.error(f"[ACTOR] Fatal exception in actor_control thread: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def _apply_torch_compile(policy, cfg: RTCDemoConfig):
|
||||
"""Apply torch.compile to the policy's predict_action_chunk method.
|
||||
|
||||
Args:
|
||||
policy: Policy instance to compile
|
||||
cfg: Configuration containing torch compile settings
|
||||
|
||||
Returns:
|
||||
Policy with compiled predict_action_chunk method
|
||||
"""
|
||||
|
||||
# PI models handle their own compilation
|
||||
if policy.type == "pi05" or policy.type == "pi0":
|
||||
return policy
|
||||
|
||||
try:
|
||||
# Check if torch.compile is available (PyTorch 2.0+)
|
||||
if not hasattr(torch, "compile"):
|
||||
logger.warning(
|
||||
f"torch.compile is not available. Requires PyTorch 2.0+. "
|
||||
f"Current version: {torch.__version__}. Skipping compilation."
|
||||
)
|
||||
return policy
|
||||
|
||||
logger.info("Applying torch.compile to predict_action_chunk...")
|
||||
logger.info(f" Backend: {cfg.torch_compile_backend}")
|
||||
logger.info(f" Mode: {cfg.torch_compile_mode}")
|
||||
logger.info(f" Disable CUDA graphs: {cfg.torch_compile_disable_cudagraphs}")
|
||||
|
||||
# Compile the predict_action_chunk method
|
||||
# - CUDA graphs disabled to prevent tensor aliasing from in-place ops (x_t += dt * v_t)
|
||||
compile_kwargs = {
|
||||
"backend": cfg.torch_compile_backend,
|
||||
"mode": cfg.torch_compile_mode,
|
||||
}
|
||||
|
||||
# Disable CUDA graphs if requested (prevents tensor aliasing issues)
|
||||
if cfg.torch_compile_disable_cudagraphs:
|
||||
compile_kwargs["options"] = {"triton.cudagraphs": False}
|
||||
|
||||
original_method = policy.predict_action_chunk
|
||||
compiled_method = torch.compile(original_method, **compile_kwargs)
|
||||
policy.predict_action_chunk = compiled_method
|
||||
logger.info("✓ Successfully compiled predict_action_chunk")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to apply torch.compile: {e}")
|
||||
logger.warning("Continuing without torch.compile")
|
||||
|
||||
return policy
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def demo_cli(cfg: RTCDemoConfig):
|
||||
"""Main entry point for RTC demo with draccus configuration."""
|
||||
|
||||
# Initialize logging
|
||||
init_logging()
|
||||
|
||||
logger.info(f"Using device: {cfg.device}")
|
||||
|
||||
# Setup signal handler for graceful shutdown
|
||||
signal_handler = ProcessSignalHandler(use_threads=True, display_pid=False)
|
||||
shutdown_event = signal_handler.shutdown_event
|
||||
|
||||
policy = None
|
||||
robot = None
|
||||
get_actions_thread = None
|
||||
actor_thread = None
|
||||
|
||||
policy_class = get_policy_class(cfg.policy.type)
|
||||
|
||||
# Load config and set compile_model for pi0/pi05 models
|
||||
config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
|
||||
|
||||
if cfg.policy.type == "pi05" or cfg.policy.type == "pi0":
|
||||
config.compile_model = cfg.use_torch_compile
|
||||
|
||||
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)
|
||||
|
||||
# Turn on RTC
|
||||
policy.config.rtc_config = cfg.rtc
|
||||
|
||||
# Init RTC processort, as by default if RTC disabled in the config
|
||||
# The processor won't be created
|
||||
policy.init_rtc_processor()
|
||||
|
||||
assert policy.name in ["smolvla", "pi05", "pi0"], "Only smolvla, pi05, and pi0 are supported for RTC"
|
||||
|
||||
policy = policy.to(cfg.device)
|
||||
policy.eval()
|
||||
|
||||
# Apply torch.compile to predict_action_chunk method if enabled
|
||||
if cfg.use_torch_compile:
|
||||
policy = _apply_torch_compile(policy, cfg)
|
||||
|
||||
# Create robot
|
||||
logger.info(f"Initializing robot: {cfg.robot.type}")
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
robot.connect()
|
||||
robot_wrapper = RobotWrapper(robot)
|
||||
|
||||
# Create robot observation processor
|
||||
robot_observation_processor = make_default_robot_observation_processor()
|
||||
robot_action_processor = make_default_robot_action_processor()
|
||||
|
||||
# Create action queue for communication between threads
|
||||
action_queue = ActionQueue(cfg.rtc)
|
||||
|
||||
# Start chunk requester thread
|
||||
get_actions_thread = Thread(
|
||||
target=get_actions,
|
||||
args=(policy, robot_wrapper, robot_observation_processor, action_queue, shutdown_event, cfg),
|
||||
daemon=True,
|
||||
name="GetActions",
|
||||
)
|
||||
get_actions_thread.start()
|
||||
logger.info("Started get actions thread")
|
||||
|
||||
# Start action executor thread
|
||||
actor_thread = Thread(
|
||||
target=actor_control,
|
||||
args=(robot_wrapper, robot_action_processor, action_queue, shutdown_event, cfg),
|
||||
daemon=True,
|
||||
name="Actor",
|
||||
)
|
||||
actor_thread.start()
|
||||
logger.info("Started actor thread")
|
||||
|
||||
logger.info("Started stop by duration thread")
|
||||
|
||||
# Main thread monitors for duration or shutdown
|
||||
logger.info(f"Running demo for {cfg.duration} seconds...")
|
||||
start_time = time.time()
|
||||
|
||||
while not shutdown_event.is_set() and (time.time() - start_time) < cfg.duration:
|
||||
time.sleep(10)
|
||||
|
||||
# Log queue status periodically
|
||||
if int(time.time() - start_time) % 5 == 0:
|
||||
logger.info(f"[MAIN] Action queue size: {action_queue.qsize()}")
|
||||
|
||||
if time.time() - start_time > cfg.duration:
|
||||
break
|
||||
|
||||
logger.info("Demo duration reached or shutdown requested")
|
||||
|
||||
# Signal shutdown
|
||||
shutdown_event.set()
|
||||
|
||||
# Wait for threads to finish
|
||||
if get_actions_thread and get_actions_thread.is_alive():
|
||||
logger.info("Waiting for chunk requester thread to finish...")
|
||||
get_actions_thread.join()
|
||||
|
||||
if actor_thread and actor_thread.is_alive():
|
||||
logger.info("Waiting for action executor thread to finish...")
|
||||
actor_thread.join()
|
||||
|
||||
# Cleanup robot
|
||||
if robot:
|
||||
robot.disconnect()
|
||||
logger.info("Robot disconnected")
|
||||
|
||||
logger.info("Cleanup completed")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo_cli()
|
||||
logging.info("RTC demo finished")
|
||||
@@ -0,0 +1,631 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Profiled version of eval_with_real_robot.py for performance analysis.
|
||||
|
||||
This version adds detailed timing measurements for:
|
||||
- Policy inference
|
||||
- Preprocessing
|
||||
- Postprocessing
|
||||
- Action queue operations
|
||||
- Robot communication
|
||||
- Thread execution times
|
||||
|
||||
Usage: Same as eval_with_real_robot.py but with profiling output.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from threading import Event, Lock, Thread
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
|
||||
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
||||
from lerobot.policies.rtc.action_queue import ActionQueue
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.latency_tracker import LatencyTracker
|
||||
from lerobot.processor.factory import (
|
||||
make_default_robot_action_processor,
|
||||
make_default_robot_observation_processor,
|
||||
)
|
||||
from lerobot.rl.process import ProcessSignalHandler
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
koch_follower,
|
||||
so100_follower,
|
||||
so101_follower,
|
||||
)
|
||||
from lerobot.robots.utils import make_robot_from_config
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
from lerobot.utils.hub import HubMixin
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProfileTimer:
|
||||
"""Context manager and utility class for timing code sections."""
|
||||
|
||||
def __init__(self, name: str, stats_dict: dict):
|
||||
self.name = name
|
||||
self.stats_dict = stats_dict
|
||||
self.start_time = None
|
||||
|
||||
def __enter__(self):
|
||||
self.start_time = time.perf_counter()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
elapsed = time.perf_counter() - self.start_time
|
||||
if self.name not in self.stats_dict:
|
||||
self.stats_dict[self.name] = []
|
||||
self.stats_dict[self.name].append(elapsed)
|
||||
|
||||
|
||||
class ProfilingStats:
|
||||
"""Global profiling statistics collector."""
|
||||
|
||||
def __init__(self):
|
||||
self.stats = defaultdict(list)
|
||||
self.lock = Lock()
|
||||
|
||||
def record(self, name: str, duration: float):
|
||||
with self.lock:
|
||||
self.stats[name].append(duration)
|
||||
|
||||
def timer(self, name: str):
|
||||
"""Return a context manager for timing."""
|
||||
return ProfileTimer(name, self.stats)
|
||||
|
||||
def get_summary(self) -> dict[str, dict[str, float]]:
|
||||
"""Get summary statistics for all timings."""
|
||||
with self.lock:
|
||||
summary = {}
|
||||
for name, times in self.stats.items():
|
||||
if times:
|
||||
summary[name] = {
|
||||
"count": len(times),
|
||||
"mean": sum(times) / len(times),
|
||||
"min": min(times),
|
||||
"max": max(times),
|
||||
"total": sum(times),
|
||||
}
|
||||
return summary
|
||||
|
||||
def print_summary(self):
|
||||
"""Print formatted summary of all timings."""
|
||||
summary = self.get_summary()
|
||||
|
||||
logger.info("\n" + "=" * 80)
|
||||
logger.info("PROFILING SUMMARY")
|
||||
logger.info("=" * 80)
|
||||
|
||||
# Sort by total time (descending)
|
||||
sorted_items = sorted(summary.items(), key=lambda x: x[1]["total"], reverse=True)
|
||||
|
||||
for name, stats in sorted_items:
|
||||
logger.info(f"\n{name}:")
|
||||
logger.info(f" Count: {stats['count']}")
|
||||
logger.info(f" Mean: {stats['mean']*1000:.2f} ms")
|
||||
logger.info(f" Min: {stats['min']*1000:.2f} ms")
|
||||
logger.info(f" Max: {stats['max']*1000:.2f} ms")
|
||||
logger.info(f" Total: {stats['total']:.2f} s")
|
||||
logger.info(f" Hz: {stats['count']/stats['total']:.2f}")
|
||||
|
||||
logger.info("\n" + "=" * 80)
|
||||
|
||||
|
||||
# Global profiling stats
|
||||
profiling_stats = ProfilingStats()
|
||||
|
||||
|
||||
class RobotWrapper:
|
||||
def __init__(self, robot: Robot):
|
||||
self.robot = robot
|
||||
self.lock = Lock()
|
||||
|
||||
def get_observation(self) -> dict[str, Tensor]:
|
||||
with profiling_stats.timer("robot.get_observation"):
|
||||
with self.lock:
|
||||
return self.robot.get_observation()
|
||||
|
||||
def send_action(self, action: Tensor):
|
||||
with profiling_stats.timer("robot.send_action"):
|
||||
with self.lock:
|
||||
self.robot.send_action(action)
|
||||
|
||||
def observation_features(self) -> list[str]:
|
||||
with self.lock:
|
||||
return self.robot.observation_features
|
||||
|
||||
def action_features(self) -> list[str]:
|
||||
with self.lock:
|
||||
return self.robot.action_features
|
||||
|
||||
|
||||
@dataclass
|
||||
class RTCDemoConfig(HubMixin):
|
||||
"""Configuration for RTC demo with action chunking policies and real robots."""
|
||||
|
||||
# Policy configuration
|
||||
policy: PreTrainedConfig | None = None
|
||||
|
||||
# Robot configuration
|
||||
robot: RobotConfig | None = None
|
||||
|
||||
# RTC configuration
|
||||
rtc: RTCConfig = field(
|
||||
default_factory=lambda: RTCConfig(
|
||||
execution_horizon=10,
|
||||
max_guidance_weight=1.0,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||
)
|
||||
)
|
||||
|
||||
# Demo parameters
|
||||
duration: float = 30.0 # Duration to run the demo (seconds)
|
||||
fps: float = 10.0 # Action execution frequency (Hz)
|
||||
|
||||
# Compute device
|
||||
device: str | None = None # Device to run on (cuda, cpu, auto)
|
||||
|
||||
# Get new actions horizon. The amount of executed steps after which will be requested new actions.
|
||||
# It should be higher than inference delay + execution horizon.
|
||||
action_queue_size_to_get_new_actions: int = 30
|
||||
|
||||
# Task to execute
|
||||
task: str = field(default="", metadata={"help": "Task to execute"})
|
||||
|
||||
# Torch compile configuration
|
||||
use_torch_compile: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use torch.compile for faster inference (PyTorch 2.0+)"},
|
||||
)
|
||||
|
||||
torch_compile_backend: str = field(
|
||||
default="inductor",
|
||||
metadata={"help": "Backend for torch.compile (inductor, aot_eager, cudagraphs)"},
|
||||
)
|
||||
|
||||
torch_compile_mode: str = field(
|
||||
default="default",
|
||||
metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"},
|
||||
)
|
||||
|
||||
torch_compile_disable_cudagraphs: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "Disable CUDA graphs in torch.compile. Required due to in-place tensor "
|
||||
"operations in denoising loop (x_t += dt * v_t) which cause tensor aliasing issues."
|
||||
},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
else:
|
||||
raise ValueError("Policy path is required")
|
||||
|
||||
# Validate that robot configuration is provided
|
||||
if self.robot is None:
|
||||
raise ValueError("Robot configuration must be provided")
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
|
||||
return ["policy"]
|
||||
|
||||
|
||||
def is_image_key(k: str) -> bool:
|
||||
return k.startswith(OBS_IMAGES)
|
||||
|
||||
|
||||
def get_actions(
|
||||
policy,
|
||||
robot: RobotWrapper,
|
||||
robot_observation_processor,
|
||||
action_queue: ActionQueue,
|
||||
shutdown_event: Event,
|
||||
cfg: RTCDemoConfig,
|
||||
):
|
||||
"""Thread function to request action chunks from the policy with profiling.
|
||||
|
||||
Args:
|
||||
policy: The policy instance (SmolVLA, Pi0, etc.)
|
||||
robot: The robot instance for getting observations
|
||||
robot_observation_processor: Processor for raw robot observations
|
||||
action_queue: Queue to put new action chunks
|
||||
shutdown_event: Event to signal shutdown
|
||||
cfg: Demo configuration
|
||||
"""
|
||||
try:
|
||||
logger.info("[GET_ACTIONS] Starting get actions thread")
|
||||
|
||||
latency_tracker = LatencyTracker() # Track latency of action chunks
|
||||
fps = cfg.fps
|
||||
time_per_chunk = 1.0 / fps
|
||||
|
||||
dataset_features = hw_to_dataset_features(robot.observation_features(), "observation")
|
||||
policy_device = policy.config.device
|
||||
|
||||
# Load preprocessor and postprocessor from pretrained files
|
||||
logger.info(f"[GET_ACTIONS] Loading preprocessor/postprocessor from {cfg.policy.pretrained_path}")
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=None, # Will load from pretrained processor files
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": cfg.policy.device},
|
||||
},
|
||||
)
|
||||
|
||||
logger.info("[GET_ACTIONS] Preprocessor/postprocessor loaded successfully with embedded stats")
|
||||
|
||||
get_actions_threshold = cfg.action_queue_size_to_get_new_actions
|
||||
|
||||
if not cfg.rtc.enabled:
|
||||
get_actions_threshold = 0
|
||||
|
||||
inference_count = 0
|
||||
|
||||
while not shutdown_event.is_set():
|
||||
if action_queue.qsize() <= get_actions_threshold:
|
||||
with profiling_stats.timer("get_actions.total_iteration"):
|
||||
inference_count += 1
|
||||
logger.info(f"[GET_ACTIONS] Starting inference #{inference_count}")
|
||||
|
||||
current_time = time.perf_counter()
|
||||
action_index_before_inference = action_queue.get_action_index()
|
||||
|
||||
with profiling_stats.timer("get_actions.get_prev_actions"):
|
||||
prev_actions = action_queue.get_left_over()
|
||||
|
||||
inference_latency = latency_tracker.max()
|
||||
inference_delay = math.ceil(inference_latency / time_per_chunk)
|
||||
|
||||
# Get observation
|
||||
obs = robot.get_observation()
|
||||
|
||||
# Apply robot observation processor
|
||||
with profiling_stats.timer("get_actions.robot_obs_processing"):
|
||||
obs_processed = robot_observation_processor(obs)
|
||||
|
||||
# Build dataset frame
|
||||
with profiling_stats.timer("get_actions.build_dataset_frame"):
|
||||
obs_with_policy_features = build_dataset_frame(
|
||||
dataset_features, obs_processed, prefix="observation"
|
||||
)
|
||||
|
||||
# Convert to tensors and normalize
|
||||
with profiling_stats.timer("get_actions.tensor_conversion"):
|
||||
for name in obs_with_policy_features:
|
||||
obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name])
|
||||
if "image" in name:
|
||||
obs_with_policy_features[name] = (
|
||||
obs_with_policy_features[name].type(torch.float32) / 255
|
||||
)
|
||||
obs_with_policy_features[name] = (
|
||||
obs_with_policy_features[name].permute(2, 0, 1).contiguous()
|
||||
)
|
||||
obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0)
|
||||
obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device)
|
||||
|
||||
obs_with_policy_features["task"] = [cfg.task]
|
||||
obs_with_policy_features["robot_type"] = (
|
||||
robot.robot.name if hasattr(robot.robot, "name") else ""
|
||||
)
|
||||
|
||||
# Preprocessing
|
||||
with profiling_stats.timer("get_actions.preprocessing"):
|
||||
preproceseded_obs = preprocessor(obs_with_policy_features)
|
||||
|
||||
# Policy inference
|
||||
with profiling_stats.timer("get_actions.policy_inference"):
|
||||
actions = policy.predict_action_chunk(
|
||||
preproceseded_obs,
|
||||
inference_delay=inference_delay,
|
||||
prev_chunk_left_over=prev_actions,
|
||||
)
|
||||
|
||||
# Clone for RTC
|
||||
with profiling_stats.timer("get_actions.clone_actions"):
|
||||
original_actions = actions.squeeze(0).clone()
|
||||
|
||||
# Postprocessing
|
||||
with profiling_stats.timer("get_actions.postprocessing"):
|
||||
postprocessed_actions = postprocessor(actions)
|
||||
postprocessed_actions = postprocessed_actions.squeeze(0)
|
||||
|
||||
# Update latency tracker
|
||||
new_latency = time.perf_counter() - current_time
|
||||
new_delay = math.ceil(new_latency / time_per_chunk)
|
||||
latency_tracker.add(new_latency)
|
||||
|
||||
logger.info(
|
||||
f"[GET_ACTIONS] Inference #{inference_count} completed in {new_latency*1000:.2f}ms "
|
||||
f"(delay={new_delay} chunks)"
|
||||
)
|
||||
|
||||
if cfg.action_queue_size_to_get_new_actions < cfg.rtc.execution_horizon + new_delay:
|
||||
logger.warning(
|
||||
"[GET_ACTIONS] cfg.action_queue_size_to_get_new_actions Too small, "
|
||||
"It should be higher than inference delay + execution horizon."
|
||||
)
|
||||
|
||||
# Merge into action queue
|
||||
with profiling_stats.timer("get_actions.action_queue_merge"):
|
||||
action_queue.merge(
|
||||
original_actions, postprocessed_actions, new_delay, action_index_before_inference
|
||||
)
|
||||
else:
|
||||
# Small sleep to prevent busy waiting
|
||||
time.sleep(0.1)
|
||||
|
||||
logger.info("[GET_ACTIONS] get actions thread shutting down")
|
||||
except Exception as e:
|
||||
logger.error(f"[GET_ACTIONS] Fatal exception in get_actions thread: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def actor_control(
|
||||
robot: RobotWrapper,
|
||||
robot_action_processor,
|
||||
action_queue: ActionQueue,
|
||||
shutdown_event: Event,
|
||||
cfg: RTCDemoConfig,
|
||||
):
|
||||
"""Thread function to execute actions on the robot with profiling.
|
||||
|
||||
Args:
|
||||
robot: The robot instance
|
||||
action_queue: Queue to get actions from
|
||||
shutdown_event: Event to signal shutdown
|
||||
cfg: Demo configuration
|
||||
"""
|
||||
try:
|
||||
logger.info("[ACTOR] Starting actor thread")
|
||||
|
||||
action_count = 0
|
||||
action_interval = 1.0 / cfg.fps
|
||||
|
||||
while not shutdown_event.is_set():
|
||||
start_time = time.perf_counter()
|
||||
|
||||
with profiling_stats.timer("actor.total_iteration"):
|
||||
# Get action from queue
|
||||
with profiling_stats.timer("actor.queue_get"):
|
||||
action = action_queue.get()
|
||||
|
||||
if action is not None:
|
||||
# Process action
|
||||
with profiling_stats.timer("actor.action_processing"):
|
||||
action = action.cpu()
|
||||
action_dict = {key: action[i].item() for i, key in enumerate(robot.action_features())}
|
||||
action_processed = robot_action_processor((action_dict, None))
|
||||
|
||||
# Send to robot (includes robot.send_action timing)
|
||||
robot.send_action(action_processed)
|
||||
action_count += 1
|
||||
|
||||
# Sleep to maintain target FPS
|
||||
dt_s = time.perf_counter() - start_time
|
||||
sleep_time = max(0, (action_interval - dt_s) - 0.001)
|
||||
if sleep_time > 0:
|
||||
time.sleep(sleep_time)
|
||||
|
||||
logger.info(f"[ACTOR] Actor thread shutting down. Total actions executed: {action_count}")
|
||||
except Exception as e:
|
||||
logger.error(f"[ACTOR] Fatal exception in actor_control thread: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def _apply_torch_compile(policy, cfg: RTCDemoConfig):
|
||||
"""Apply torch.compile to the policy's predict_action_chunk method.
|
||||
|
||||
Args:
|
||||
policy: Policy instance to compile
|
||||
cfg: Configuration containing torch compile settings
|
||||
|
||||
Returns:
|
||||
Policy with compiled predict_action_chunk method
|
||||
"""
|
||||
|
||||
# PI models handle their own compilation
|
||||
if policy.type == "pi05" or policy.type == "pi0":
|
||||
return policy
|
||||
|
||||
try:
|
||||
# Check if torch.compile is available (PyTorch 2.0+)
|
||||
if not hasattr(torch, "compile"):
|
||||
logger.warning(
|
||||
f"torch.compile is not available. Requires PyTorch 2.0+. "
|
||||
f"Current version: {torch.__version__}. Skipping compilation."
|
||||
)
|
||||
return policy
|
||||
|
||||
logger.info("Applying torch.compile to predict_action_chunk...")
|
||||
logger.info(f" Backend: {cfg.torch_compile_backend}")
|
||||
logger.info(f" Mode: {cfg.torch_compile_mode}")
|
||||
logger.info(f" Disable CUDA graphs: {cfg.torch_compile_disable_cudagraphs}")
|
||||
|
||||
# Compile the predict_action_chunk method
|
||||
compile_kwargs = {
|
||||
"backend": cfg.torch_compile_backend,
|
||||
"mode": cfg.torch_compile_mode,
|
||||
}
|
||||
|
||||
# Disable CUDA graphs if requested (prevents tensor aliasing issues)
|
||||
if cfg.torch_compile_disable_cudagraphs:
|
||||
compile_kwargs["options"] = {"triton.cudagraphs": False}
|
||||
|
||||
original_method = policy.predict_action_chunk
|
||||
compiled_method = torch.compile(original_method, **compile_kwargs)
|
||||
policy.predict_action_chunk = compiled_method
|
||||
logger.info("✓ Successfully compiled predict_action_chunk")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to apply torch.compile: {e}")
|
||||
logger.warning("Continuing without torch.compile")
|
||||
|
||||
return policy
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def demo_cli(cfg: RTCDemoConfig):
|
||||
"""Main entry point for RTC demo with profiling."""
|
||||
|
||||
# Initialize logging
|
||||
init_logging()
|
||||
|
||||
logger.info(f"Using device: {cfg.device}")
|
||||
logger.info("=" * 80)
|
||||
logger.info("PROFILING MODE ENABLED")
|
||||
logger.info("=" * 80)
|
||||
|
||||
# Setup signal handler for graceful shutdown
|
||||
signal_handler = ProcessSignalHandler(use_threads=True, display_pid=False)
|
||||
shutdown_event = signal_handler.shutdown_event
|
||||
|
||||
policy = None
|
||||
robot = None
|
||||
get_actions_thread = None
|
||||
actor_thread = None
|
||||
|
||||
policy_class = get_policy_class(cfg.policy.type)
|
||||
|
||||
# Load config and set compile_model for pi0/pi05 models
|
||||
config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
|
||||
|
||||
if cfg.policy.type == "pi05" or cfg.policy.type == "pi0":
|
||||
config.compile_model = cfg.use_torch_compile
|
||||
|
||||
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)
|
||||
|
||||
# Turn on RTC
|
||||
policy.config.rtc_config = cfg.rtc
|
||||
|
||||
# Init RTC processor
|
||||
policy.init_rtc_processor()
|
||||
|
||||
assert policy.name in ["smolvla", "pi05", "pi0"], "Only smolvla, pi05, and pi0 are supported for RTC"
|
||||
|
||||
policy = policy.to(cfg.device)
|
||||
policy.eval()
|
||||
|
||||
# Apply torch.compile to predict_action_chunk method if enabled
|
||||
if cfg.use_torch_compile:
|
||||
policy = _apply_torch_compile(policy, cfg)
|
||||
|
||||
# Create robot
|
||||
logger.info(f"Initializing robot: {cfg.robot.type}")
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
robot.connect()
|
||||
robot_wrapper = RobotWrapper(robot)
|
||||
|
||||
# Create robot observation processor
|
||||
robot_observation_processor = make_default_robot_observation_processor()
|
||||
robot_action_processor = make_default_robot_action_processor()
|
||||
|
||||
# Create action queue for communication between threads
|
||||
action_queue = ActionQueue(cfg.rtc)
|
||||
|
||||
# Start chunk requester thread
|
||||
get_actions_thread = Thread(
|
||||
target=get_actions,
|
||||
args=(policy, robot_wrapper, robot_observation_processor, action_queue, shutdown_event, cfg),
|
||||
daemon=True,
|
||||
name="GetActions",
|
||||
)
|
||||
get_actions_thread.start()
|
||||
logger.info("Started get actions thread")
|
||||
|
||||
# Start action executor thread
|
||||
actor_thread = Thread(
|
||||
target=actor_control,
|
||||
args=(robot_wrapper, robot_action_processor, action_queue, shutdown_event, cfg),
|
||||
daemon=True,
|
||||
name="Actor",
|
||||
)
|
||||
actor_thread.start()
|
||||
logger.info("Started actor thread")
|
||||
|
||||
logger.info("Started stop by duration thread")
|
||||
|
||||
# Main thread monitors for duration or shutdown
|
||||
logger.info(f"Running demo for {cfg.duration} seconds...")
|
||||
start_time = time.time()
|
||||
|
||||
while not shutdown_event.is_set() and (time.time() - start_time) < cfg.duration:
|
||||
time.sleep(10)
|
||||
|
||||
# Log queue status periodically
|
||||
if int(time.time() - start_time) % 5 == 0:
|
||||
logger.info(f"[MAIN] Action queue size: {action_queue.qsize()}")
|
||||
|
||||
if time.time() - start_time > cfg.duration:
|
||||
break
|
||||
|
||||
logger.info("Demo duration reached or shutdown requested")
|
||||
|
||||
# Signal shutdown
|
||||
shutdown_event.set()
|
||||
|
||||
# Wait for threads to finish
|
||||
if get_actions_thread and get_actions_thread.is_alive():
|
||||
logger.info("Waiting for chunk requester thread to finish...")
|
||||
get_actions_thread.join()
|
||||
|
||||
if actor_thread and actor_thread.is_alive():
|
||||
logger.info("Waiting for action executor thread to finish...")
|
||||
actor_thread.join()
|
||||
|
||||
# Cleanup robot
|
||||
if robot:
|
||||
robot.disconnect()
|
||||
logger.info("Robot disconnected")
|
||||
|
||||
# Print profiling summary
|
||||
profiling_stats.print_summary()
|
||||
|
||||
logger.info("Cleanup completed")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo_cli()
|
||||
logging.info("RTC demo finished")
|
||||
|
||||
@@ -0,0 +1,358 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
"""
|
||||
Comprehensive profiling script for Pi0 with RTC.
|
||||
|
||||
This script demonstrates how to use all the profiling tools to identify
|
||||
bottlenecks in Pi0 policy inference with RTC enabled.
|
||||
|
||||
It profiles:
|
||||
1. Overall inference time
|
||||
2. RTC-specific operations (guidance, weights, etc.)
|
||||
3. Preprocessing/postprocessing
|
||||
4. Individual method timings
|
||||
|
||||
Usage:
|
||||
uv run examples/rtc/profile_pi0_rtc_detailed.py \
|
||||
--policy_path=helper2424/pi05_check_rtc \
|
||||
--device=mps \
|
||||
--num_iterations=20 \
|
||||
--execution_horizon=20 \
|
||||
--enable_rtc_profiling
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.utils.profiling import (
|
||||
ProfileContext,
|
||||
clear_profiling_stats,
|
||||
enable_profiling,
|
||||
get_profiling_stats,
|
||||
print_profiling_summary,
|
||||
)
|
||||
|
||||
# Import monkey patching for RTC profiling
|
||||
try:
|
||||
from examples.rtc.add_rtc_profiling import monkey_patch_rtc_profiling
|
||||
except ImportError:
|
||||
logging.warning("Could not import add_rtc_profiling, detailed RTC profiling disabled")
|
||||
monkey_patch_rtc_profiling = None
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_mock_observation(policy_config, device: str) -> dict:
|
||||
"""Create a mock observation matching policy requirements.
|
||||
|
||||
Args:
|
||||
policy_config: Policy configuration
|
||||
device: Device to create tensors on
|
||||
|
||||
Returns:
|
||||
Mock observation dictionary
|
||||
"""
|
||||
obs = {}
|
||||
|
||||
# Create mock state observation
|
||||
state_dim = 10 # Typical robot state dimension
|
||||
obs["observation.state"] = torch.randn(1, state_dim, device=device)
|
||||
|
||||
# Create mock images if needed
|
||||
# For Pi0, we typically need at least one image
|
||||
image_height = 224
|
||||
image_width = 224
|
||||
|
||||
# Common image keys for Pi0
|
||||
image_keys = ["observation.images.gripper", "observation.images.front"]
|
||||
|
||||
for key in image_keys:
|
||||
# Images should be [B, C, H, W] and normalized to [0, 1]
|
||||
obs[key] = torch.rand(1, 3, image_height, image_width, device=device)
|
||||
|
||||
# Add task
|
||||
obs["task"] = ["Pick up the object"]
|
||||
|
||||
# Add language tokens and attention mask (required for Pi0)
|
||||
# These are mock values - in real usage they come from tokenizer
|
||||
max_seq_len = 32
|
||||
obs["observation.language_tokens"] = torch.randint(0, 1000, (1, max_seq_len), device=device)
|
||||
obs["observation.language_attention_mask"] = torch.ones(1, max_seq_len, device=device)
|
||||
|
||||
return obs
|
||||
|
||||
|
||||
def profile_single_iteration(
|
||||
policy,
|
||||
preprocessor,
|
||||
postprocessor,
|
||||
observation: dict,
|
||||
prev_actions: torch.Tensor | None,
|
||||
use_rtc: bool,
|
||||
inference_delay: int = 0,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None, dict]:
|
||||
"""Profile a single inference iteration.
|
||||
|
||||
Args:
|
||||
policy: Policy instance
|
||||
preprocessor: Observation preprocessor
|
||||
postprocessor: Action postprocessor
|
||||
observation: Input observation
|
||||
prev_actions: Previous action chunk (for RTC)
|
||||
use_rtc: Whether RTC is enabled
|
||||
inference_delay: Inference delay in timesteps
|
||||
|
||||
Returns:
|
||||
Tuple of (actions, new_prev_actions, timings)
|
||||
"""
|
||||
timings = {}
|
||||
|
||||
with ProfileContext("iteration.total"):
|
||||
# Preprocessing
|
||||
with ProfileContext("iteration.preprocessing"):
|
||||
preprocessed_obs = preprocessor(observation)
|
||||
|
||||
# Policy inference
|
||||
with ProfileContext("iteration.policy_inference"):
|
||||
if use_rtc:
|
||||
actions = policy.predict_action_chunk(
|
||||
preprocessed_obs,
|
||||
inference_delay=inference_delay,
|
||||
prev_chunk_left_over=prev_actions,
|
||||
)
|
||||
else:
|
||||
actions = policy.predict_action_chunk(preprocessed_obs)
|
||||
|
||||
# Clone for next iteration (if RTC)
|
||||
new_prev_actions = None
|
||||
if use_rtc:
|
||||
with ProfileContext("iteration.prepare_prev_actions"):
|
||||
execution_horizon = policy.config.rtc_config.execution_horizon
|
||||
if actions.shape[1] > execution_horizon:
|
||||
new_prev_actions = actions[:, execution_horizon:].clone()
|
||||
|
||||
# Postprocessing
|
||||
with ProfileContext("iteration.postprocessing"):
|
||||
processed_actions = postprocessor(actions)
|
||||
|
||||
return processed_actions, new_prev_actions, timings
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Detailed profiling for Pi0 with RTC")
|
||||
parser.add_argument("--policy_path", type=str, required=True, help="Path to pretrained policy")
|
||||
parser.add_argument("--device", type=str, default="cuda", help="Device (cuda/cpu/mps)")
|
||||
parser.add_argument("--num_iterations", type=int, default=20, help="Number of iterations")
|
||||
parser.add_argument("--execution_horizon", type=int, default=10, help="RTC execution horizon")
|
||||
parser.add_argument("--warmup_iterations", type=int, default=5, help="Warmup iterations")
|
||||
parser.add_argument("--enable_rtc_profiling", action="store_true", help="Enable detailed RTC profiling")
|
||||
parser.add_argument("--use_torch_compile", action="store_true", help="Use torch.compile")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
logger.info("="*80)
|
||||
logger.info("DETAILED PI0 RTC PROFILING")
|
||||
logger.info("="*80)
|
||||
logger.info(f"Policy: {args.policy_path}")
|
||||
logger.info(f"Device: {args.device}")
|
||||
logger.info(f"Iterations: {args.num_iterations}")
|
||||
logger.info(f"Execution Horizon: {args.execution_horizon}")
|
||||
logger.info(f"RTC Profiling: {args.enable_rtc_profiling}")
|
||||
logger.info("="*80 + "\n")
|
||||
|
||||
# Enable profiling
|
||||
enable_profiling()
|
||||
|
||||
# Apply RTC profiling if requested
|
||||
if args.enable_rtc_profiling:
|
||||
if monkey_patch_rtc_profiling is not None:
|
||||
monkey_patch_rtc_profiling()
|
||||
logger.info("✓ Detailed RTC profiling enabled\n")
|
||||
else:
|
||||
logger.warning("⚠ Could not enable detailed RTC profiling\n")
|
||||
|
||||
# Load policy
|
||||
logger.info("Loading policy...")
|
||||
config = PreTrainedConfig.from_pretrained(args.policy_path)
|
||||
|
||||
if hasattr(config, "compile_model"):
|
||||
config.compile_model = args.use_torch_compile
|
||||
|
||||
policy_class = get_policy_class(config.type)
|
||||
policy = policy_class.from_pretrained(args.policy_path, config=config)
|
||||
|
||||
# Configure RTC
|
||||
policy.config.rtc_config = RTCConfig(
|
||||
enabled=True,
|
||||
execution_horizon=args.execution_horizon,
|
||||
max_guidance_weight=1.0,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||
)
|
||||
policy.init_rtc_processor()
|
||||
|
||||
policy = policy.to(args.device)
|
||||
policy.eval()
|
||||
|
||||
logger.info(f"✓ Policy loaded: {config.type}\n")
|
||||
|
||||
# Create preprocessor and postprocessor
|
||||
logger.info("Loading preprocessor/postprocessor...")
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=config,
|
||||
pretrained_path=args.policy_path,
|
||||
dataset_stats=None,
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": args.device},
|
||||
},
|
||||
)
|
||||
logger.info("✓ Preprocessor/postprocessor loaded\n")
|
||||
|
||||
# Create mock observation
|
||||
logger.info("Creating mock observation...")
|
||||
observation = create_mock_observation(config, args.device)
|
||||
logger.info("✓ Mock observation created\n")
|
||||
|
||||
# Warmup
|
||||
logger.info(f"Warming up ({args.warmup_iterations} iterations)...")
|
||||
prev_actions = None
|
||||
for i in range(args.warmup_iterations):
|
||||
with torch.no_grad():
|
||||
_, prev_actions, _ = profile_single_iteration(
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
observation=observation,
|
||||
prev_actions=prev_actions,
|
||||
use_rtc=True,
|
||||
inference_delay=0,
|
||||
)
|
||||
|
||||
# Clear warmup stats
|
||||
clear_profiling_stats()
|
||||
logger.info("✓ Warmup complete\n")
|
||||
|
||||
# Profiled run WITH RTC
|
||||
logger.info(f"Running profiled iterations WITH RTC ({args.num_iterations} iterations)...")
|
||||
prev_actions = None
|
||||
iteration_times = []
|
||||
|
||||
for i in range(args.num_iterations):
|
||||
start = time.perf_counter()
|
||||
|
||||
with torch.no_grad():
|
||||
_, prev_actions, _ = profile_single_iteration(
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
observation=observation,
|
||||
prev_actions=prev_actions,
|
||||
use_rtc=True,
|
||||
inference_delay=0,
|
||||
)
|
||||
|
||||
# Sync CUDA if needed
|
||||
if args.device.startswith("cuda"):
|
||||
torch.cuda.synchronize()
|
||||
|
||||
elapsed = time.perf_counter() - start
|
||||
iteration_times.append(elapsed)
|
||||
|
||||
if (i + 1) % 5 == 0:
|
||||
logger.info(f" Completed {i+1}/{args.num_iterations}")
|
||||
|
||||
logger.info("✓ Profiling complete\n")
|
||||
|
||||
# Print summary statistics
|
||||
logger.info("\n" + "="*80)
|
||||
logger.info("ITERATION TIMING SUMMARY")
|
||||
logger.info("="*80)
|
||||
|
||||
times_arr = np.array(iteration_times)
|
||||
logger.info(f"Mean time: {np.mean(times_arr)*1000:.2f} ms")
|
||||
logger.info(f"Median time: {np.median(times_arr)*1000:.2f} ms")
|
||||
logger.info(f"Std dev: {np.std(times_arr)*1000:.2f} ms")
|
||||
logger.info(f"Min time: {np.min(times_arr)*1000:.2f} ms")
|
||||
logger.info(f"Max time: {np.max(times_arr)*1000:.2f} ms")
|
||||
logger.info(f"Total time: {np.sum(times_arr):.2f} s")
|
||||
logger.info(f"Throughput: {len(times_arr)/np.sum(times_arr):.2f} iter/s")
|
||||
logger.info("="*80 + "\n")
|
||||
|
||||
# Print detailed profiling breakdown
|
||||
print_profiling_summary(sort_by="total")
|
||||
|
||||
# Print key insights
|
||||
stats = get_profiling_stats()
|
||||
|
||||
logger.info("\n" + "="*80)
|
||||
logger.info("KEY INSIGHTS")
|
||||
logger.info("="*80)
|
||||
|
||||
# Find bottlenecks
|
||||
if stats:
|
||||
policy_inference_time = stats.get("iteration.policy_inference", {}).get("mean", 0)
|
||||
preprocessing_time = stats.get("iteration.preprocessing", {}).get("mean", 0)
|
||||
postprocessing_time = stats.get("iteration.postprocessing", {}).get("mean", 0)
|
||||
|
||||
total_time = policy_inference_time + preprocessing_time + postprocessing_time
|
||||
|
||||
if total_time > 0:
|
||||
logger.info(f"\nTime breakdown:")
|
||||
logger.info(f" Policy inference: {policy_inference_time*1000:.2f} ms ({policy_inference_time/total_time*100:.1f}%)")
|
||||
logger.info(f" Preprocessing: {preprocessing_time*1000:.2f} ms ({preprocessing_time/total_time*100:.1f}%)")
|
||||
logger.info(f" Postprocessing: {postprocessing_time*1000:.2f} ms ({postprocessing_time/total_time*100:.1f}%)")
|
||||
|
||||
# RTC-specific insights
|
||||
if args.enable_rtc_profiling:
|
||||
rtc_guidance = stats.get("rtc.denoise_step.guidance_computation", {}).get("mean", 0)
|
||||
rtc_autograd = stats.get("rtc.denoise_step.autograd_correction", {}).get("mean", 0)
|
||||
rtc_base = stats.get("rtc.denoise_step.base_denoising", {}).get("mean", 0)
|
||||
|
||||
if rtc_guidance > 0:
|
||||
logger.info(f"\nRTC breakdown:")
|
||||
logger.info(f" Base denoising: {rtc_base*1000:.2f} ms")
|
||||
logger.info(f" Guidance compute: {rtc_guidance*1000:.2f} ms")
|
||||
logger.info(f" Autograd correct: {rtc_autograd*1000:.2f} ms")
|
||||
logger.info(f" RTC overhead: {(rtc_guidance - rtc_base)*1000:.2f} ms")
|
||||
|
||||
# Recommendations
|
||||
logger.info("\nRecommendations:")
|
||||
|
||||
if preprocessing_time > policy_inference_time * 0.3:
|
||||
logger.info(" ⚠ Preprocessing is taking >30% of time")
|
||||
logger.info(" → Consider reducing image resolution")
|
||||
logger.info(" → Consider using fewer cameras")
|
||||
|
||||
if args.enable_rtc_profiling and rtc_autograd > rtc_base * 0.5:
|
||||
logger.info(" ⚠ RTC autograd overhead is significant")
|
||||
logger.info(" → This is expected, but consider increasing execution_horizon")
|
||||
logger.info(" → Try torch.compile if not already enabled")
|
||||
|
||||
if not args.use_torch_compile:
|
||||
logger.info(" 💡 torch.compile not enabled")
|
||||
logger.info(" → Try --use_torch_compile for potential speedup")
|
||||
|
||||
logger.info("="*80 + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main()
|
||||
except KeyboardInterrupt:
|
||||
logger.info("\n\nProfiling interrupted by user")
|
||||
sys.exit(0)
|
||||
except Exception as e:
|
||||
logger.error(f"\n\nError during profiling: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
|
||||
@@ -0,0 +1,347 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
"""
|
||||
Script to compare performance with and without RTC enabled.
|
||||
|
||||
This script helps identify whether RTC is actually improving or degrading performance
|
||||
by running multiple inference passes and collecting detailed timing statistics.
|
||||
|
||||
Usage:
|
||||
# Profile with mock data (no robot needed)
|
||||
uv run examples/rtc/profile_rtc_comparison.py \
|
||||
--policy_path=helper2424/pi05_check_rtc \
|
||||
--device=mps \
|
||||
--num_iterations=50
|
||||
|
||||
# Profile with specific RTC config
|
||||
uv run examples/rtc/profile_rtc_comparison.py \
|
||||
--policy_path=helper2424/pi05_check_rtc \
|
||||
--device=mps \
|
||||
--num_iterations=50 \
|
||||
--execution_horizon=20
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.utils.profiling import (
|
||||
clear_profiling_stats,
|
||||
enable_profiling,
|
||||
get_profiling_stats,
|
||||
print_profiling_summary,
|
||||
)
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProfileResults:
|
||||
"""Results from profiling run."""
|
||||
|
||||
mode: str # "with_rtc" or "without_rtc"
|
||||
mean_time: float
|
||||
std_time: float
|
||||
min_time: float
|
||||
max_time: float
|
||||
times: list[float]
|
||||
throughput: float # iterations per second
|
||||
|
||||
|
||||
def create_mock_observation(policy, device: str) -> dict:
|
||||
"""Create a mock observation for testing.
|
||||
|
||||
Args:
|
||||
policy: Policy instance
|
||||
device: Device to create tensors on
|
||||
|
||||
Returns:
|
||||
Mock observation dictionary
|
||||
"""
|
||||
# Get expected input shapes from policy config
|
||||
# This is a simplified version - adjust based on actual policy requirements
|
||||
obs = {}
|
||||
|
||||
# Mock image observations (if needed)
|
||||
if hasattr(policy.config, "input_shapes"):
|
||||
for key, shape in policy.config.input_shapes.items():
|
||||
if "image" in key:
|
||||
# Typical image shape: (batch, channels, height, width)
|
||||
obs[key] = torch.randn(1, *shape, device=device)
|
||||
else:
|
||||
obs[key] = torch.randn(1, *shape, device=device)
|
||||
|
||||
# Add task if needed
|
||||
if "task" in policy.config.__dict__ or hasattr(policy, "accepts_task"):
|
||||
obs["task"] = ["Pick up the object"]
|
||||
|
||||
# Mock state observation
|
||||
obs["observation.state"] = torch.randn(1, 10, device=device) # Adjust size as needed
|
||||
|
||||
return obs
|
||||
|
||||
|
||||
def profile_inference(
|
||||
policy, observation: dict, num_iterations: int, use_rtc: bool, execution_horizon: int = 10
|
||||
) -> ProfileResults:
|
||||
"""Profile policy inference with or without RTC.
|
||||
|
||||
Args:
|
||||
policy: Policy instance
|
||||
observation: Observation dictionary
|
||||
num_iterations: Number of inference iterations to run
|
||||
use_rtc: Whether to enable RTC
|
||||
execution_horizon: Execution horizon for RTC
|
||||
|
||||
Returns:
|
||||
ProfileResults with timing statistics
|
||||
"""
|
||||
mode = "with_rtc" if use_rtc else "without_rtc"
|
||||
logger.info(f"\n{'='*80}")
|
||||
logger.info(f"Profiling: {mode.upper()}")
|
||||
logger.info(f"{'='*80}")
|
||||
|
||||
# Configure RTC
|
||||
if use_rtc:
|
||||
policy.config.rtc_config.enabled = True
|
||||
policy.config.rtc_config.execution_horizon = execution_horizon
|
||||
policy.init_rtc_processor()
|
||||
else:
|
||||
policy.config.rtc_config.enabled = False
|
||||
|
||||
times = []
|
||||
prev_actions = None
|
||||
|
||||
# Warmup
|
||||
logger.info("Warming up (5 iterations)...")
|
||||
for _ in range(5):
|
||||
with torch.no_grad():
|
||||
if use_rtc:
|
||||
_ = policy.predict_action_chunk(
|
||||
observation, inference_delay=0, prev_chunk_left_over=prev_actions
|
||||
)
|
||||
else:
|
||||
_ = policy.predict_action_chunk(observation)
|
||||
|
||||
# Actual profiling
|
||||
logger.info(f"Running {num_iterations} profiled iterations...")
|
||||
for i in range(num_iterations):
|
||||
start = time.perf_counter()
|
||||
|
||||
with torch.no_grad():
|
||||
if use_rtc:
|
||||
actions = policy.predict_action_chunk(
|
||||
observation, inference_delay=0, prev_chunk_left_over=prev_actions
|
||||
)
|
||||
# Simulate consuming some actions for next iteration
|
||||
if actions.shape[1] > execution_horizon:
|
||||
prev_actions = actions[:, execution_horizon:].clone()
|
||||
else:
|
||||
prev_actions = None
|
||||
else:
|
||||
actions = policy.predict_action_chunk(observation)
|
||||
|
||||
# Synchronize if using CUDA
|
||||
if observation["observation.state"].device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
|
||||
elapsed = time.perf_counter() - start
|
||||
times.append(elapsed)
|
||||
|
||||
if (i + 1) % 10 == 0:
|
||||
logger.info(f" Completed {i+1}/{num_iterations} iterations")
|
||||
|
||||
# Calculate statistics
|
||||
times_arr = np.array(times)
|
||||
results = ProfileResults(
|
||||
mode=mode,
|
||||
mean_time=float(np.mean(times_arr)),
|
||||
std_time=float(np.std(times_arr)),
|
||||
min_time=float(np.min(times_arr)),
|
||||
max_time=float(np.max(times_arr)),
|
||||
times=times,
|
||||
throughput=num_iterations / sum(times),
|
||||
)
|
||||
|
||||
logger.info(f"\nResults for {mode}:")
|
||||
logger.info(f" Mean time: {results.mean_time*1000:.2f} ms")
|
||||
logger.info(f" Std dev: {results.std_time*1000:.2f} ms")
|
||||
logger.info(f" Min time: {results.min_time*1000:.2f} ms")
|
||||
logger.info(f" Max time: {results.max_time*1000:.2f} ms")
|
||||
logger.info(f" Throughput: {results.throughput:.2f} iter/s")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def compare_results(results_without_rtc: ProfileResults, results_with_rtc: ProfileResults):
|
||||
"""Compare and print results from both runs.
|
||||
|
||||
Args:
|
||||
results_without_rtc: Results from run without RTC
|
||||
results_with_rtc: Results from run with RTC
|
||||
"""
|
||||
logger.info(f"\n{'='*80}")
|
||||
logger.info("COMPARISON SUMMARY")
|
||||
logger.info(f"{'='*80}")
|
||||
|
||||
mean_diff = results_with_rtc.mean_time - results_without_rtc.mean_time
|
||||
mean_diff_pct = (mean_diff / results_without_rtc.mean_time) * 100
|
||||
|
||||
throughput_diff = results_with_rtc.throughput - results_without_rtc.throughput
|
||||
throughput_diff_pct = (throughput_diff / results_without_rtc.throughput) * 100
|
||||
|
||||
logger.info(f"\n{'Metric':<30} {'Without RTC':>15} {'With RTC':>15} {'Difference':>15}")
|
||||
logger.info("-" * 80)
|
||||
logger.info(
|
||||
f"{'Mean time (ms)':<30} "
|
||||
f"{results_without_rtc.mean_time*1000:>15.2f} "
|
||||
f"{results_with_rtc.mean_time*1000:>15.2f} "
|
||||
f"{mean_diff*1000:>+15.2f}"
|
||||
)
|
||||
logger.info(
|
||||
f"{'Std dev (ms)':<30} "
|
||||
f"{results_without_rtc.std_time*1000:>15.2f} "
|
||||
f"{results_with_rtc.std_time*1000:>15.2f} "
|
||||
f"{(results_with_rtc.std_time - results_without_rtc.std_time)*1000:>+15.2f}"
|
||||
)
|
||||
logger.info(
|
||||
f"{'Min time (ms)':<30} "
|
||||
f"{results_without_rtc.min_time*1000:>15.2f} "
|
||||
f"{results_with_rtc.min_time*1000:>15.2f} "
|
||||
f"{(results_with_rtc.min_time - results_without_rtc.min_time)*1000:>+15.2f}"
|
||||
)
|
||||
logger.info(
|
||||
f"{'Max time (ms)':<30} "
|
||||
f"{results_without_rtc.max_time*1000:>15.2f} "
|
||||
f"{results_with_rtc.max_time*1000:>15.2f} "
|
||||
f"{(results_with_rtc.max_time - results_without_rtc.max_time)*1000:>+15.2f}"
|
||||
)
|
||||
logger.info(
|
||||
f"{'Throughput (iter/s)':<30} "
|
||||
f"{results_without_rtc.throughput:>15.2f} "
|
||||
f"{results_with_rtc.throughput:>15.2f} "
|
||||
f"{throughput_diff:>+15.2f}"
|
||||
)
|
||||
|
||||
logger.info(f"\n{'='*80}")
|
||||
logger.info("VERDICT")
|
||||
logger.info(f"{'='*80}")
|
||||
|
||||
if mean_diff_pct < -5:
|
||||
logger.info(f"✓ RTC is FASTER by {abs(mean_diff_pct):.1f}%")
|
||||
logger.info(f" Mean time reduced by {abs(mean_diff)*1000:.2f} ms")
|
||||
elif mean_diff_pct > 5:
|
||||
logger.info(f"✗ RTC is SLOWER by {mean_diff_pct:.1f}%")
|
||||
logger.info(f" Mean time increased by {mean_diff*1000:.2f} ms")
|
||||
logger.info("\n Possible reasons:")
|
||||
logger.info(" - RTC overhead exceeds benefits at current execution horizon")
|
||||
logger.info(" - Inference delay calculation not accounting for RTC processing")
|
||||
logger.info(" - Additional tensor operations in RTC guidance")
|
||||
else:
|
||||
logger.info(f"≈ Performance is SIMILAR (difference: {mean_diff_pct:+.1f}%)")
|
||||
|
||||
logger.info(f"{'='*80}\n")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Profile RTC performance")
|
||||
parser.add_argument(
|
||||
"--policy_path", type=str, required=True, help="Path to pretrained policy"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device", type=str, default="cuda", help="Device to run on (cuda/cpu/mps)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_iterations", type=int, default=50, help="Number of inference iterations"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--execution_horizon", type=int, default=10, help="RTC execution horizon"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable_detailed_profiling",
|
||||
action="store_true",
|
||||
help="Enable detailed method-level profiling",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_torch_compile", action="store_true", help="Use torch.compile for faster inference"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load policy
|
||||
logger.info(f"Loading policy from {args.policy_path}")
|
||||
config = PreTrainedConfig.from_pretrained(args.policy_path)
|
||||
policy_class = get_policy_class(config.type)
|
||||
|
||||
# Set compile flag if needed
|
||||
if hasattr(config, "compile_model"):
|
||||
config.compile_model = args.use_torch_compile
|
||||
|
||||
policy = policy_class.from_pretrained(args.policy_path, config=config)
|
||||
|
||||
# Initialize RTC config
|
||||
policy.config.rtc_config = RTCConfig(
|
||||
execution_horizon=args.execution_horizon,
|
||||
max_guidance_weight=1.0,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||
)
|
||||
|
||||
policy = policy.to(args.device)
|
||||
policy.eval()
|
||||
|
||||
logger.info(f"Policy loaded: {config.type}")
|
||||
logger.info(f"Device: {args.device}")
|
||||
logger.info(f"Execution horizon: {args.execution_horizon}")
|
||||
|
||||
# Create mock observation
|
||||
logger.info("Creating mock observation...")
|
||||
observation = create_mock_observation(policy, args.device)
|
||||
|
||||
# Enable detailed profiling if requested
|
||||
if args.enable_detailed_profiling:
|
||||
enable_profiling()
|
||||
logger.info("Detailed profiling enabled")
|
||||
|
||||
# Profile without RTC
|
||||
results_without_rtc = profile_inference(
|
||||
policy=policy,
|
||||
observation=observation,
|
||||
num_iterations=args.num_iterations,
|
||||
use_rtc=False,
|
||||
execution_horizon=args.execution_horizon,
|
||||
)
|
||||
|
||||
if args.enable_detailed_profiling:
|
||||
logger.info("\nDetailed profiling stats (WITHOUT RTC):")
|
||||
print_profiling_summary()
|
||||
clear_profiling_stats()
|
||||
|
||||
# Profile with RTC
|
||||
results_with_rtc = profile_inference(
|
||||
policy=policy,
|
||||
observation=observation,
|
||||
num_iterations=args.num_iterations,
|
||||
use_rtc=True,
|
||||
execution_horizon=args.execution_horizon,
|
||||
)
|
||||
|
||||
if args.enable_detailed_profiling:
|
||||
logger.info("\nDetailed profiling stats (WITH RTC):")
|
||||
print_profiling_summary()
|
||||
|
||||
# Compare results
|
||||
compare_results(results_without_rtc, results_with_rtc)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
+2
-1
@@ -98,6 +98,7 @@ pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
||||
placo-dep = ["placo>=0.9.6,<0.10.0"]
|
||||
transformers-dep = ["transformers>=4.53.0,<5.0.0"]
|
||||
grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"] # TODO: Bumb dependency (compatible with wandb)
|
||||
matplotlib-dep = ["matplotlib>=3.10.3,<4.0.0"]
|
||||
|
||||
# Motors
|
||||
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
|
||||
@@ -132,7 +133,7 @@ groot = [
|
||||
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
|
||||
# Features
|
||||
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"]
|
||||
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
|
||||
|
||||
# Development
|
||||
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1"]
|
||||
|
||||
@@ -43,3 +43,10 @@ class NormalizationMode(str, Enum):
|
||||
class PolicyFeature:
|
||||
type: FeatureType
|
||||
shape: tuple[int, ...]
|
||||
|
||||
|
||||
class RTCAttentionSchedule(str, Enum):
|
||||
ZEROS = "ZEROS"
|
||||
ONES = "ONES"
|
||||
LINEAR = "LINEAR"
|
||||
EXP = "EXP"
|
||||
|
||||
@@ -20,6 +20,7 @@ from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
|
||||
|
||||
@@ -47,6 +48,9 @@ class PI0Config(PreTrainedConfig):
|
||||
min_period: float = 4e-3
|
||||
max_period: float = 4.0
|
||||
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
rtc_config: RTCConfig | None = None
|
||||
|
||||
image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
|
||||
|
||||
# Add empty images. Used to add empty cameras when no image features are present.
|
||||
|
||||
@@ -19,11 +19,12 @@ import logging
|
||||
import math
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
from typing import TYPE_CHECKING, Literal, TypedDict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
@@ -42,6 +43,7 @@ else:
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
@@ -51,6 +53,12 @@ from lerobot.utils.constants import (
|
||||
)
|
||||
|
||||
|
||||
class ActionSelectKwargs(TypedDict, total=False):
|
||||
inference_delay: int | None
|
||||
prev_chunk_left_over: Tensor | None
|
||||
execution_horizon: int | None
|
||||
|
||||
|
||||
def get_safe_dtype(target_dtype, device_type):
|
||||
"""Get a safe dtype for the given device type."""
|
||||
if device_type == "mps" and target_dtype == torch.float64:
|
||||
@@ -503,9 +511,10 @@ class PaliGemmaWithExpertModel(
|
||||
class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
"""Core PI0 PyTorch model."""
|
||||
|
||||
def __init__(self, config: PI0Config):
|
||||
def __init__(self, config: PI0Config, rtc_processor: RTCProcessor | None = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.rtc_processor = rtc_processor
|
||||
|
||||
paligemma_config = get_gemma_config(config.paligemma_variant)
|
||||
action_expert_config = get_gemma_config(config.action_expert_variant)
|
||||
@@ -560,6 +569,9 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
|
||||
logging.info("Disabled gradient checkpointing for PI0Pytorch model")
|
||||
|
||||
def _rtc_enabled(self):
|
||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||
|
||||
def _apply_checkpoint(self, func, *args, **kwargs):
|
||||
"""Helper method to apply gradient checkpointing if enabled."""
|
||||
if self.gradient_checkpointing_enabled and self.training:
|
||||
@@ -756,7 +768,15 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
@torch.no_grad() # see openpi `sample_actions` (slightly adapted)
|
||||
def sample_actions(
|
||||
self, images, img_masks, lang_tokens, lang_masks, state, noise=None, num_steps=None
|
||||
self,
|
||||
images,
|
||||
img_masks,
|
||||
lang_tokens,
|
||||
lang_masks,
|
||||
state,
|
||||
noise=None,
|
||||
num_steps=None,
|
||||
**kwargs: Unpack[ActionSelectKwargs],
|
||||
) -> Tensor:
|
||||
"""Do a full inference forward and compute the action."""
|
||||
if num_steps is None:
|
||||
@@ -798,14 +818,41 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
time = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
while time >= -dt / 2:
|
||||
expanded_time = time.expand(bsize)
|
||||
v_t = self.denoise_step(
|
||||
state,
|
||||
prefix_pad_masks,
|
||||
past_key_values,
|
||||
x_t,
|
||||
expanded_time,
|
||||
)
|
||||
x_t = x_t + dt * v_t
|
||||
|
||||
# Define a closure function to properly capture expanded_time
|
||||
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
|
||||
return self.denoise_step(
|
||||
state=state,
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
x_t=input_x_t,
|
||||
timestep=current_timestep,
|
||||
)
|
||||
|
||||
if self._rtc_enabled():
|
||||
inference_delay = kwargs.get("inference_delay")
|
||||
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
||||
execution_horizon = kwargs.get("execution_horizon")
|
||||
|
||||
v_t = self.rtc_processor.denoise_step(
|
||||
x_t=x_t,
|
||||
prev_chunk_left_over=prev_chunk_left_over,
|
||||
inference_delay=inference_delay,
|
||||
time=time,
|
||||
original_denoise_step_partial=denoise_step_partial_call,
|
||||
execution_horizon=execution_horizon,
|
||||
)
|
||||
else:
|
||||
v_t = denoise_step_partial_call(x_t)
|
||||
|
||||
# Euler step
|
||||
x_t += dt * v_t
|
||||
|
||||
# Record x_t and v_t after Euler step
|
||||
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
|
||||
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
|
||||
|
||||
time += dt
|
||||
|
||||
return x_t
|
||||
@@ -869,7 +916,8 @@ class PI0Policy(PreTrainedPolicy):
|
||||
self.config = config
|
||||
|
||||
# Initialize the core PI0 model
|
||||
self.model = PI0Pytorch(config)
|
||||
self.init_rtc_processor()
|
||||
self.model = PI0Pytorch(config, rtc_processor=self.rtc_processor)
|
||||
|
||||
# Enable gradient checkpointing if requested
|
||||
if config.gradient_checkpointing:
|
||||
@@ -1059,6 +1107,22 @@ class PI0Policy(PreTrainedPolicy):
|
||||
ACTION: deque(maxlen=self.config.n_action_steps),
|
||||
}
|
||||
|
||||
def init_rtc_processor(self):
|
||||
"""Initialize RTC processor if RTC is enabled in config."""
|
||||
self.rtc_processor = None
|
||||
|
||||
# Create processor if config provided
|
||||
# If RTC is not enabled - we can still track the denoising data
|
||||
if self.config.rtc_config is not None:
|
||||
self.rtc_processor = RTCProcessor(self.config.rtc_config)
|
||||
|
||||
model_value = getattr(self, "model", None)
|
||||
if model_value is not None:
|
||||
model_value.rtc_processor = self.rtc_processor
|
||||
|
||||
def _rtc_enabled(self) -> bool:
|
||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||
|
||||
def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]:
|
||||
"""Preprocess images for the model.
|
||||
|
||||
@@ -1137,6 +1201,10 @@ class PI0Policy(PreTrainedPolicy):
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select a single action given environment observations."""
|
||||
assert not self._rtc_enabled(), (
|
||||
"RTC is not supported for select_action, use it with predict_action_chunk"
|
||||
)
|
||||
|
||||
self.eval()
|
||||
|
||||
# Action queue logic for n_action_steps > 1
|
||||
@@ -1148,7 +1216,7 @@ class PI0Policy(PreTrainedPolicy):
|
||||
return self._action_queue.popleft()
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor:
|
||||
"""Predict a chunk of actions given environment observations."""
|
||||
self.eval()
|
||||
|
||||
@@ -1157,8 +1225,8 @@ class PI0Policy(PreTrainedPolicy):
|
||||
lang_tokens, lang_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
||||
state = self.prepare_state(batch)
|
||||
|
||||
# Sample actions using the model
|
||||
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state)
|
||||
# Sample actions using the model (pass through RTC kwargs)
|
||||
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, **kwargs)
|
||||
|
||||
# Unpad actions to actual action dimension
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
|
||||
@@ -20,6 +20,7 @@ from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("pi05")
|
||||
@@ -46,6 +47,9 @@ class PI05Config(PreTrainedConfig):
|
||||
min_period: float = 4e-3
|
||||
max_period: float = 4.0
|
||||
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
rtc_config: RTCConfig | None = None
|
||||
|
||||
image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
|
||||
|
||||
# Add empty images. Used to add empty cameras when no image features are present.
|
||||
|
||||
@@ -19,11 +19,12 @@ import logging
|
||||
import math
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
from typing import TYPE_CHECKING, Literal, TypedDict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
@@ -42,6 +43,7 @@ else:
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
@@ -50,6 +52,12 @@ from lerobot.utils.constants import (
|
||||
)
|
||||
|
||||
|
||||
class ActionSelectKwargs(TypedDict, total=False):
|
||||
inference_delay: int | None
|
||||
prev_chunk_left_over: Tensor | None
|
||||
execution_horizon: int | None
|
||||
|
||||
|
||||
def get_safe_dtype(target_dtype, device_type):
|
||||
"""Get a safe dtype for the given device type."""
|
||||
if device_type == "mps" and target_dtype == torch.float64:
|
||||
@@ -502,9 +510,10 @@ class PaliGemmaWithExpertModel(
|
||||
class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
"""Core PI05 PyTorch model."""
|
||||
|
||||
def __init__(self, config: PI05Config):
|
||||
def __init__(self, config: PI05Config, rtc_processor: RTCProcessor | None = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.rtc_processor = rtc_processor
|
||||
|
||||
paligemma_config = get_gemma_config(config.paligemma_variant)
|
||||
action_expert_config = get_gemma_config(config.action_expert_variant)
|
||||
@@ -556,6 +565,9 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
|
||||
logging.info("Disabled gradient checkpointing for PI05Pytorch model")
|
||||
|
||||
def _rtc_enabled(self):
|
||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||
|
||||
def _apply_checkpoint(self, func, *args, **kwargs):
|
||||
"""Helper method to apply gradient checkpointing if enabled."""
|
||||
if self.gradient_checkpointing_enabled and self.training:
|
||||
@@ -731,7 +743,16 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
return F.mse_loss(u_t, v_t, reduction="none")
|
||||
|
||||
@torch.no_grad() # see openpi `sample_actions` (slightly adapted)
|
||||
def sample_actions(self, images, img_masks, tokens, masks, noise=None, num_steps=None) -> Tensor:
|
||||
def sample_actions(
|
||||
self,
|
||||
images,
|
||||
img_masks,
|
||||
tokens,
|
||||
masks,
|
||||
noise=None,
|
||||
num_steps=None,
|
||||
**kwargs: Unpack[ActionSelectKwargs],
|
||||
) -> Tensor:
|
||||
"""Do a full inference forward and compute the action."""
|
||||
if num_steps is None:
|
||||
num_steps = self.config.num_inference_steps
|
||||
@@ -770,13 +791,40 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
time = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
while time >= -dt / 2:
|
||||
expanded_time = time.expand(bsize)
|
||||
v_t = self.denoise_step(
|
||||
prefix_pad_masks,
|
||||
past_key_values,
|
||||
x_t,
|
||||
expanded_time,
|
||||
)
|
||||
x_t = x_t + dt * v_t
|
||||
|
||||
# Define a closure function to properly capture expanded_time
|
||||
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
|
||||
return self.denoise_step(
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
x_t=input_x_t,
|
||||
timestep=current_timestep,
|
||||
)
|
||||
|
||||
if self._rtc_enabled():
|
||||
inference_delay = kwargs.get("inference_delay")
|
||||
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
||||
execution_horizon = kwargs.get("execution_horizon")
|
||||
|
||||
v_t = self.rtc_processor.denoise_step(
|
||||
x_t=x_t,
|
||||
prev_chunk_left_over=prev_chunk_left_over,
|
||||
inference_delay=inference_delay,
|
||||
time=time,
|
||||
original_denoise_step_partial=denoise_step_partial_call,
|
||||
execution_horizon=execution_horizon,
|
||||
)
|
||||
else:
|
||||
v_t = denoise_step_partial_call(x_t)
|
||||
|
||||
# Euler step
|
||||
x_t += dt * v_t
|
||||
|
||||
# Record x_t and v_t after Euler step
|
||||
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
|
||||
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
|
||||
|
||||
time += dt
|
||||
|
||||
return x_t
|
||||
@@ -839,7 +887,8 @@ class PI05Policy(PreTrainedPolicy):
|
||||
self.config = config
|
||||
|
||||
# Initialize the core PI05 model
|
||||
self.model = PI05Pytorch(config)
|
||||
self.init_rtc_processor()
|
||||
self.model = PI05Pytorch(config, rtc_processor=self.rtc_processor)
|
||||
|
||||
# Enable gradient checkpointing if requested
|
||||
if config.gradient_checkpointing:
|
||||
@@ -1035,6 +1084,22 @@ class PI05Policy(PreTrainedPolicy):
|
||||
ACTION: deque(maxlen=self.config.n_action_steps),
|
||||
}
|
||||
|
||||
def init_rtc_processor(self):
|
||||
"""Initialize RTC processor if RTC is enabled in config."""
|
||||
self.rtc_processor = None
|
||||
|
||||
# Create processor if config provided
|
||||
# If RTC is not enabled - we can still track the denoising data
|
||||
if self.config.rtc_config is not None:
|
||||
self.rtc_processor = RTCProcessor(self.config.rtc_config)
|
||||
|
||||
model_value = getattr(self, "model", None)
|
||||
if model_value is not None:
|
||||
model_value.rtc_processor = self.rtc_processor
|
||||
|
||||
def _rtc_enabled(self) -> bool:
|
||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||
|
||||
def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]:
|
||||
"""Preprocess images for the model.
|
||||
|
||||
@@ -1109,6 +1174,10 @@ class PI05Policy(PreTrainedPolicy):
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select a single action given environment observations."""
|
||||
assert not self._rtc_enabled(), (
|
||||
"RTC is not supported for select_action, use it with predict_action_chunk"
|
||||
)
|
||||
|
||||
self.eval()
|
||||
|
||||
# Action queue logic for n_action_steps > 1
|
||||
@@ -1120,7 +1189,7 @@ class PI05Policy(PreTrainedPolicy):
|
||||
return self._action_queue.popleft()
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor:
|
||||
"""Predict a chunk of actions given environment observations."""
|
||||
self.eval()
|
||||
|
||||
@@ -1128,8 +1197,8 @@ class PI05Policy(PreTrainedPolicy):
|
||||
images, img_masks = self._preprocess_images(batch)
|
||||
tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
||||
|
||||
# Sample actions using the model (no separate state needed for PI05)
|
||||
actions = self.model.sample_actions(images, img_masks, tokens, masks)
|
||||
# Sample actions using the model (pass through RTC kwargs, no separate state needed for PI05)
|
||||
actions = self.model.sample_actions(images, img_masks, tokens, masks, **kwargs)
|
||||
|
||||
# Unpad actions to actual action dimension
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
# Real-Time Chunking (RTC) Module
|
||||
|
||||
This module implements Real-Time Chunking and related adaptive inference techniques for robotics policies in LeRobot.
|
||||
|
||||
## Overview
|
||||
|
||||
Real-Time Chunking (RTC) addresses the challenge of real-time inference in action chunking policies by treating chunk generation as an inpainting problem. It strategically handles overlapping timesteps between action chunks using prefix attention mechanisms.
|
||||
|
||||
It is particularly effective for handling long-horizon inference in robotics policies.
|
||||
|
||||
## Integration with Policies
|
||||
|
||||
RTC can be integrated with any policy that supports flow mathicng for chunking:
|
||||
|
||||
- **SmolVLA**: Vision-language-action model with RTC support
|
||||
- **Pi0**: Action prediction model with adaptive chunking
|
||||
- **Pi05**: Action prediction model with adaptive chunking
|
||||
|
||||
## Original Implementation
|
||||
|
||||
This implementation is based on Physical Intelligence's Kinetix RTC:
|
||||
|
||||
- [Original RTC implementation](https://github.com/Physical-Intelligence/real-time-chunking-kinetix/blob/main/src/model.py#L214)
|
||||
- [Kinetix GitHub Repository](https://github.com/Physical-Intelligence/real-time-chunking-kinetix)
|
||||
|
||||
## References
|
||||
|
||||
- [Real Time Chunking Paper](https://www.physicalintelligence.company/research/real_time_chunking)
|
||||
- [Physical Intelligence Kinetix](https://github.com/Physical-Intelligence/real-time-chunking-kinetix)
|
||||
|
||||
## How to run
|
||||
|
||||
### Check with data from the dataset
|
||||
|
||||
```bash
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=helper2424/check_rtc \
|
||||
--rtc.execution_horizon=8 \
|
||||
--device=mps \
|
||||
--seed=42
|
||||
```
|
||||
|
||||
This script will evaluate RTC on a data from a dataset and save the results to a file, u can check the results in the `rtc_debug_output` directory.
|
||||
|
||||
The example output should look like this:
|
||||

|
||||
|
||||
It shows how flow matching works with RTC and without it. The chart shows values of action predictions for each timestep. The colour shows the the generation progress. The blue ones - earlier timesteps, the yellow ones - later timesteps. The red line is the ground truth (previous action chunk).
|
||||
@@ -0,0 +1,219 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Action queue management for Real-Time Chunking (RTC).
|
||||
|
||||
This module provides ActionQueue, a thread-safe queue for managing action chunks
|
||||
in real-time control scenarios. It supports both RTC-enabled and non-RTC modes,
|
||||
handling action merging and leftover tracking.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from threading import Lock
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ActionQueue:
|
||||
"""Thread-safe queue for managing action chunks in real-time control.
|
||||
|
||||
This queue handles two types of action sequences:
|
||||
- Original actions: Used for RTC to compute leftovers from previous chunks
|
||||
- Processed actions: Post-processed actions ready for robot execution
|
||||
|
||||
The queue operates in two modes:
|
||||
1. RTC-enabled: Replaces the entire queue with new actions, accounting for inference delay
|
||||
2. RTC-disabled: Appends new actions to the queue, maintaining continuity
|
||||
|
||||
Args:
|
||||
cfg (RTCConfig): Configuration for Real-Time Chunking behavior.
|
||||
|
||||
Attributes:
|
||||
queue (Tensor | None): Processed actions for robot rollout (time_steps, action_dim).
|
||||
original_queue (Tensor | None): Original actions for RTC computation (time_steps, action_dim).
|
||||
last_index (int): Current consumption index in the queue.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg: RTCConfig):
|
||||
"""Initialize the action queue.
|
||||
|
||||
Args:
|
||||
cfg: RTC configuration controlling queue behavior.
|
||||
"""
|
||||
self.queue = None # Processed actions for robot rollout
|
||||
self.original_queue = None # Original actions for RTC
|
||||
self.lock = Lock()
|
||||
self.last_index = 0
|
||||
self.cfg = cfg
|
||||
|
||||
def get(self) -> Tensor | None:
|
||||
"""Get the next action from the queue.
|
||||
|
||||
Returns:
|
||||
Tensor | None: The next action (action_dim,) or None if queue is empty.
|
||||
Returns a clone to prevent external modifications.
|
||||
"""
|
||||
with self.lock:
|
||||
if self.queue is None or self.last_index >= len(self.queue):
|
||||
return None
|
||||
|
||||
action = self.queue[self.last_index]
|
||||
self.last_index += 1
|
||||
return action.clone()
|
||||
|
||||
def qsize(self) -> int:
|
||||
"""Get the number of remaining actions in the queue.
|
||||
|
||||
Returns:
|
||||
int: Number of unconsumed actions.
|
||||
"""
|
||||
if self.queue is None:
|
||||
return 0
|
||||
length = len(self.queue)
|
||||
return length - self.last_index
|
||||
|
||||
def empty(self) -> bool:
|
||||
"""Check if the queue is empty.
|
||||
|
||||
Returns:
|
||||
bool: True if no actions remain, False otherwise.
|
||||
"""
|
||||
if self.queue is None:
|
||||
return True
|
||||
|
||||
length = len(self.queue)
|
||||
return length - self.last_index <= 0
|
||||
|
||||
def get_action_index(self) -> int:
|
||||
"""Get the current action consumption index.
|
||||
|
||||
Returns:
|
||||
int: Index of the next action to be consumed.
|
||||
"""
|
||||
return self.last_index
|
||||
|
||||
def get_left_over(self) -> Tensor | None:
|
||||
"""Get leftover original actions for RTC prev_chunk_left_over.
|
||||
|
||||
These are the unconsumed actions from the current chunk, which will be
|
||||
used by RTC to compute corrections for the next chunk.
|
||||
|
||||
Returns:
|
||||
Tensor | None: Remaining original actions (remaining_steps, action_dim),
|
||||
or None if no original queue exists.
|
||||
"""
|
||||
with self.lock:
|
||||
if self.original_queue is None:
|
||||
return None
|
||||
return self.original_queue[self.last_index :]
|
||||
|
||||
def merge(
|
||||
self,
|
||||
original_actions: Tensor,
|
||||
processed_actions: Tensor,
|
||||
real_delay: int,
|
||||
action_index_before_inference: int | None = 0,
|
||||
):
|
||||
"""Merge new actions into the queue.
|
||||
|
||||
This method operates differently based on RTC mode:
|
||||
- RTC enabled: Replaces the queue, accounting for inference delay
|
||||
- RTC disabled: Appends to the queue, maintaining continuity
|
||||
|
||||
Args:
|
||||
original_actions: Unprocessed actions from policy (time_steps, action_dim).
|
||||
processed_actions: Post-processed actions for robot (time_steps, action_dim).
|
||||
real_delay: Number of time steps of inference delay.
|
||||
action_index_before_inference: Index before inference started, for validation.
|
||||
"""
|
||||
with self.lock:
|
||||
self._check_delays(real_delay, action_index_before_inference)
|
||||
|
||||
if self.cfg.enabled:
|
||||
self._replace_actions_queue(original_actions, processed_actions, real_delay)
|
||||
return
|
||||
|
||||
self._append_actions_queue(original_actions, processed_actions)
|
||||
|
||||
def _replace_actions_queue(self, original_actions: Tensor, processed_actions: Tensor, real_delay: int):
|
||||
"""Replace the queue with new actions (RTC mode).
|
||||
|
||||
Discards the first `real_delay` actions since they correspond to the time
|
||||
spent during inference, when the robot was executing previous actions.
|
||||
|
||||
Args:
|
||||
original_actions: Unprocessed actions from policy.
|
||||
processed_actions: Post-processed actions for robot.
|
||||
real_delay: Number of time steps to skip due to inference delay.
|
||||
"""
|
||||
self.original_queue = original_actions[real_delay:].clone()
|
||||
self.queue = processed_actions[real_delay:].clone()
|
||||
|
||||
logger.debug(f"original_actions shape: {self.original_queue.shape}")
|
||||
logger.debug(f"processed_actions shape: {self.queue.shape}")
|
||||
logger.debug(f"real_delay: {real_delay}")
|
||||
|
||||
self.last_index = 0
|
||||
|
||||
def _append_actions_queue(self, original_actions: Tensor, processed_actions: Tensor):
|
||||
"""Append new actions to the queue (non-RTC mode).
|
||||
|
||||
Removes already-consumed actions and appends new ones, maintaining
|
||||
queue continuity without replacement.
|
||||
|
||||
Args:
|
||||
original_actions: Unprocessed actions from policy.
|
||||
processed_actions: Post-processed actions for robot.
|
||||
"""
|
||||
if self.queue is None:
|
||||
self.original_queue = original_actions.clone()
|
||||
self.queue = processed_actions.clone()
|
||||
return
|
||||
|
||||
self.original_queue = torch.cat([self.original_queue, original_actions.clone()])
|
||||
self.original_queue = self.original_queue[self.last_index :]
|
||||
|
||||
self.queue = torch.cat([self.queue, processed_actions.clone()])
|
||||
self.queue = self.queue[self.last_index :]
|
||||
|
||||
self.last_index = 0
|
||||
|
||||
def _check_delays(self, real_delay: int, action_index_before_inference: int | None = None):
|
||||
"""Validate that computed delays match expectations.
|
||||
|
||||
Compares the delay computed from inference latency with the actual
|
||||
number of actions consumed during inference.
|
||||
|
||||
Args:
|
||||
real_delay: Delay computed from inference latency.
|
||||
action_index_before_inference: Action index when inference started.
|
||||
"""
|
||||
if action_index_before_inference is None:
|
||||
return
|
||||
|
||||
indexes_diff = self.last_index - action_index_before_inference
|
||||
if indexes_diff != real_delay:
|
||||
# Let's check that action index difference (real delay calculated based on action queue)
|
||||
# is the same as delay calculated based on inference latency
|
||||
logger.warning(
|
||||
f"[ACTION_QUEUE] Indexes diff is not equal to real delay. "
|
||||
f"Indexes diff: {indexes_diff}, real delay: {real_delay}"
|
||||
)
|
||||
@@ -0,0 +1,55 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Real Time Chunking (RTC) and Bidirectional Decoding (BID) configuration classes.
|
||||
|
||||
Based on:
|
||||
- Real Time Chunking: https://www.physicalintelligence.company/research/real_time_chunking
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
|
||||
|
||||
@dataclass
|
||||
class RTCConfig:
|
||||
"""Configuration for Real Time Chunking (RTC) inference.
|
||||
|
||||
RTC improves real-time inference by treating chunk generation as an inpainting problem,
|
||||
strategically handling overlapping timesteps between action chunks using prefix attention.
|
||||
"""
|
||||
|
||||
# Infrastructure
|
||||
enabled: bool = False
|
||||
|
||||
# Core RTC settings
|
||||
# Todo change to exp
|
||||
prefix_attention_schedule: RTCAttentionSchedule = RTCAttentionSchedule.LINEAR
|
||||
max_guidance_weight: float = 10.0
|
||||
execution_horizon: int = 10
|
||||
|
||||
# Debug settings
|
||||
debug: bool = False
|
||||
debug_maxlen: int = 100
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate RTC configuration parameters."""
|
||||
if self.max_guidance_weight <= 0:
|
||||
raise ValueError(f"max_guidance_weight must be positive, got {self.max_guidance_weight}")
|
||||
if self.debug_maxlen <= 0:
|
||||
raise ValueError(f"debug_maxlen must be positive, got {self.debug_maxlen}")
|
||||
@@ -0,0 +1,233 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Debug information handler for Real-Time Chunking (RTC)."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class DebugStep:
|
||||
"""Container for debug information from a single denoising step.
|
||||
|
||||
Attributes:
|
||||
step_idx (int): Step index/counter.
|
||||
x_t (Tensor | None): Current latent/state tensor.
|
||||
v_t (Tensor | None): Velocity from denoiser.
|
||||
x1_t (Tensor | None): Denoised prediction (x_t - time * v_t).
|
||||
correction (Tensor | None): Correction gradient tensor.
|
||||
err (Tensor | None): Weighted error term.
|
||||
weights (Tensor | None): Prefix attention weights.
|
||||
guidance_weight (float | Tensor | None): Applied guidance weight.
|
||||
time (float | Tensor | None): Time parameter.
|
||||
inference_delay (int | None): Inference delay parameter.
|
||||
execution_horizon (int | None): Execution horizon parameter.
|
||||
metadata (dict[str, Any]): Additional metadata.
|
||||
"""
|
||||
|
||||
step_idx: int = 0
|
||||
x_t: Tensor | None = None
|
||||
v_t: Tensor | None = None
|
||||
x1_t: Tensor | None = None
|
||||
correction: Tensor | None = None
|
||||
err: Tensor | None = None
|
||||
weights: Tensor | None = None
|
||||
guidance_weight: float | Tensor | None = None
|
||||
time: float | Tensor | None = None
|
||||
inference_delay: int | None = None
|
||||
execution_horizon: int | None = None
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self, include_tensors: bool = False) -> dict[str, Any]:
|
||||
"""Convert debug step to dictionary.
|
||||
|
||||
Args:
|
||||
include_tensors (bool): If True, include tensor values. If False, only include
|
||||
tensor statistics (shape, mean, std, min, max).
|
||||
|
||||
Returns:
|
||||
Dictionary representation of the debug step.
|
||||
"""
|
||||
result = {
|
||||
"step_idx": self.step_idx,
|
||||
"guidance_weight": (
|
||||
self.guidance_weight.item()
|
||||
if isinstance(self.guidance_weight, Tensor)
|
||||
else self.guidance_weight
|
||||
),
|
||||
"time": self.time.item() if isinstance(self.time, Tensor) else self.time,
|
||||
"inference_delay": self.inference_delay,
|
||||
"execution_horizon": self.execution_horizon,
|
||||
"metadata": self.metadata.copy(),
|
||||
}
|
||||
|
||||
# Add tensor information
|
||||
tensor_fields = ["x_t", "v_t", "x1_t", "correction", "err", "weights"]
|
||||
for field_name in tensor_fields:
|
||||
tensor = getattr(self, field_name)
|
||||
if tensor is not None:
|
||||
if include_tensors:
|
||||
result[field_name] = tensor.detach().cpu()
|
||||
else:
|
||||
result[f"{field_name}_stats"] = {
|
||||
"shape": tuple(tensor.shape),
|
||||
"mean": tensor.mean().item(),
|
||||
"std": tensor.std().item(),
|
||||
"min": tensor.min().item(),
|
||||
"max": tensor.max().item(),
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class Tracker:
|
||||
"""Collects and manages debug information for RTC processing.
|
||||
|
||||
This tracker stores debug information from recent denoising steps in a dictionary,
|
||||
using time as the key for efficient lookups and updates.
|
||||
|
||||
Args:
|
||||
enabled (bool): Whether debug collection is enabled.
|
||||
maxlen (int | None): Optional sliding window size. If provided, only the
|
||||
most recent ``maxlen`` debug steps are kept. If ``None``, keeps all.
|
||||
"""
|
||||
|
||||
def __init__(self, enabled: bool = False, maxlen: int = 100):
|
||||
self.enabled = enabled
|
||||
self._steps = {} if enabled else None # Dictionary with time as key
|
||||
self._maxlen = maxlen
|
||||
self._step_counter = 0
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Clear all recorded debug information."""
|
||||
if self.enabled and self._steps is not None:
|
||||
self._steps.clear()
|
||||
self._step_counter = 0
|
||||
|
||||
@torch._dynamo.disable
|
||||
def track(
|
||||
self,
|
||||
time: float | Tensor,
|
||||
x_t: Tensor | None = None,
|
||||
v_t: Tensor | None = None,
|
||||
x1_t: Tensor | None = None,
|
||||
correction: Tensor | None = None,
|
||||
err: Tensor | None = None,
|
||||
weights: Tensor | None = None,
|
||||
guidance_weight: float | Tensor | None = None,
|
||||
inference_delay: int | None = None,
|
||||
execution_horizon: int | None = None,
|
||||
**metadata,
|
||||
) -> None:
|
||||
"""Track debug information for a denoising step at a given time.
|
||||
|
||||
If a step with the given time already exists, it will be updated with the new data.
|
||||
Otherwise, a new step will be created. Only non-None fields are updated/set.
|
||||
|
||||
Note: This method is excluded from torch.compile to avoid graph breaks from
|
||||
operations like .item() which are incompatible with compiled graphs.
|
||||
|
||||
Args:
|
||||
time (float | Tensor): Time parameter - used as the key to identify the step.
|
||||
x_t (Tensor | None): Current latent/state tensor.
|
||||
v_t (Tensor | None): Velocity from denoiser.
|
||||
x1_t (Tensor | None): Denoised prediction.
|
||||
correction (Tensor | None): Correction gradient tensor.
|
||||
err (Tensor | None): Weighted error term.
|
||||
weights (Tensor | None): Prefix attention weights.
|
||||
guidance_weight (float | Tensor | None): Applied guidance weight.
|
||||
inference_delay (int | None): Inference delay parameter.
|
||||
execution_horizon (int | None): Execution horizon parameter.
|
||||
**metadata: Additional metadata to store.
|
||||
"""
|
||||
if not self.enabled:
|
||||
return
|
||||
|
||||
# Convert time to float and round to avoid float precision issues
|
||||
time_value = time.item() if isinstance(time, Tensor) else time
|
||||
time_key = round(time_value, 6) # Use rounded time as dictionary key
|
||||
|
||||
# Check if step with this time already exists
|
||||
if time_key in self._steps:
|
||||
# Update existing step with non-None fields
|
||||
existing_step = self._steps[time_key]
|
||||
if x_t is not None:
|
||||
existing_step.x_t = x_t.detach().clone()
|
||||
if v_t is not None:
|
||||
existing_step.v_t = v_t.detach().clone()
|
||||
if x1_t is not None:
|
||||
existing_step.x1_t = x1_t.detach().clone()
|
||||
if correction is not None:
|
||||
existing_step.correction = correction.detach().clone()
|
||||
if err is not None:
|
||||
existing_step.err = err.detach().clone()
|
||||
if weights is not None:
|
||||
existing_step.weights = weights.detach().clone()
|
||||
if guidance_weight is not None:
|
||||
existing_step.guidance_weight = guidance_weight
|
||||
if inference_delay is not None:
|
||||
existing_step.inference_delay = inference_delay
|
||||
if execution_horizon is not None:
|
||||
existing_step.execution_horizon = execution_horizon
|
||||
if metadata:
|
||||
existing_step.metadata.update(metadata)
|
||||
else:
|
||||
# Create new step
|
||||
step = DebugStep(
|
||||
step_idx=self._step_counter,
|
||||
x_t=x_t.detach().clone() if x_t is not None else None,
|
||||
v_t=v_t.detach().clone() if v_t is not None else None,
|
||||
x1_t=x1_t.detach().clone() if x1_t is not None else None,
|
||||
correction=correction.detach().clone() if correction is not None else None,
|
||||
err=err.detach().clone() if err is not None else None,
|
||||
weights=weights.detach().clone() if weights is not None else None,
|
||||
guidance_weight=guidance_weight,
|
||||
time=time_value,
|
||||
inference_delay=inference_delay,
|
||||
execution_horizon=execution_horizon,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
# Add to dictionary
|
||||
self._steps[time_key] = step
|
||||
self._step_counter += 1
|
||||
|
||||
# Enforce maxlen if set
|
||||
if self._maxlen is not None and len(self._steps) > self._maxlen:
|
||||
# Remove oldest entry (first key in dict - Python 3.7+ preserves insertion order)
|
||||
oldest_key = next(iter(self._steps))
|
||||
del self._steps[oldest_key]
|
||||
|
||||
def get_all_steps(self) -> list[DebugStep]:
|
||||
"""Get all recorded debug steps.
|
||||
|
||||
Returns:
|
||||
List of all DebugStep objects (may be empty if disabled).
|
||||
"""
|
||||
if not self.enabled or self._steps is None:
|
||||
return []
|
||||
|
||||
return list(self._steps.values())
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of recorded debug steps."""
|
||||
if not self.enabled or self._steps is None:
|
||||
return 0
|
||||
return len(self._steps)
|
||||
@@ -0,0 +1,117 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Visualization utilities for RTC debug information."""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class RTCDebugVisualizer:
|
||||
"""Visualizer for RTC debug information.
|
||||
|
||||
This class provides methods to visualize debug information collected by the Tracker,
|
||||
including corrections, errors, weights, and guidance weights over denoising steps.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def plot_waypoints(
|
||||
axes,
|
||||
tensor,
|
||||
start_from: int = 0,
|
||||
color: str = "blue",
|
||||
label: str = "",
|
||||
alpha: float = 0.7,
|
||||
linewidth: float = 2,
|
||||
marker: str | None = None,
|
||||
markersize: int = 4,
|
||||
):
|
||||
"""Plot trajectories across multiple dimensions.
|
||||
|
||||
This function plots a tensor's values across time for multiple dimensions,
|
||||
with each dimension plotted on a separate axis.
|
||||
|
||||
Args:
|
||||
axes: Array of matplotlib axes (one for each dimension).
|
||||
tensor: The tensor to plot (can be torch.Tensor or numpy array).
|
||||
Shape should be (time_steps, num_dims) or (batch, time_steps, num_dims).
|
||||
start_from: Starting index for the x-axis.
|
||||
color: Color for the plot lines.
|
||||
label: Label for the plot legend.
|
||||
alpha: Transparency level for the plot.
|
||||
linewidth: Width of the plot lines.
|
||||
marker: Marker style for data points (e.g., 'o', 's', '^').
|
||||
markersize: Size of the markers.
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
# Handle None tensor
|
||||
if tensor is None:
|
||||
return
|
||||
|
||||
# Convert tensor to numpy if needed
|
||||
tensor_np = tensor.detach().cpu().numpy() if isinstance(tensor, torch.Tensor) else tensor
|
||||
|
||||
# Handle different tensor shapes
|
||||
if tensor_np.ndim == 3:
|
||||
# If batch dimension present, take first batch
|
||||
tensor_np = tensor_np[0]
|
||||
elif tensor_np.ndim == 1:
|
||||
# If 1D, reshape to (time_steps, 1)
|
||||
tensor_np = tensor_np.reshape(-1, 1)
|
||||
|
||||
# Get dimensions
|
||||
time_steps, num_dims = tensor_np.shape
|
||||
|
||||
# Create x-axis indices
|
||||
x_indices = np.arange(start_from, start_from + time_steps)
|
||||
|
||||
# Plot each dimension on its corresponding axis
|
||||
num_axes = len(axes) if hasattr(axes, "__len__") else 1
|
||||
for dim_idx in range(min(num_dims, num_axes)):
|
||||
ax = axes[dim_idx] if hasattr(axes, "__len__") else axes
|
||||
|
||||
# Plot the trajectory
|
||||
if marker:
|
||||
ax.plot(
|
||||
x_indices,
|
||||
tensor_np[:, dim_idx],
|
||||
color=color,
|
||||
label=label if dim_idx == 0 else "", # Only show label once
|
||||
alpha=alpha,
|
||||
linewidth=linewidth,
|
||||
marker=marker,
|
||||
markersize=markersize,
|
||||
)
|
||||
else:
|
||||
ax.plot(
|
||||
x_indices,
|
||||
tensor_np[:, dim_idx],
|
||||
color=color,
|
||||
label=label if dim_idx == 0 else "", # Only show label once
|
||||
alpha=alpha,
|
||||
linewidth=linewidth,
|
||||
)
|
||||
|
||||
# Add grid and labels if not already present
|
||||
if not ax.xaxis.get_label().get_text():
|
||||
ax.set_xlabel("Step", fontsize=10)
|
||||
if not ax.yaxis.get_label().get_text():
|
||||
ax.set_ylabel(f"Dim {dim_idx}", fontsize=10)
|
||||
ax.grid(True, alpha=0.3)
|
||||
|
||||
# Add legend if label provided and this is the first dimension
|
||||
if label and dim_idx == 0:
|
||||
ax.legend(loc="best", fontsize=8)
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 1.3 MiB |
@@ -0,0 +1,72 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Latency tracking utilities for Real-Time Chunking (RTC)."""
|
||||
|
||||
from collections import deque
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class LatencyTracker:
|
||||
"""Tracks recent latencies and provides max/percentile queries.
|
||||
|
||||
Args:
|
||||
maxlen (int | None): Optional sliding window size. If provided, only the
|
||||
most recent ``maxlen`` latencies are kept. If ``None``, keeps all.
|
||||
"""
|
||||
|
||||
def __init__(self, maxlen: int = 100):
|
||||
self._values = deque(maxlen=maxlen)
|
||||
self.reset()
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Clear all recorded latencies."""
|
||||
self._values.clear()
|
||||
self.max_latency = 0.0
|
||||
|
||||
def add(self, latency: float) -> None:
|
||||
"""Add a latency sample (seconds)."""
|
||||
# Ensure numeric and non-negative
|
||||
val = float(latency)
|
||||
|
||||
if val < 0:
|
||||
return
|
||||
self._values.append(val)
|
||||
self.max_latency = max(self.max_latency, val)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._values)
|
||||
|
||||
def max(self) -> float | None:
|
||||
"""Return the maximum latency or None if empty."""
|
||||
return self.max_latency
|
||||
|
||||
def percentile(self, q: float) -> float | None:
|
||||
"""Return the q-quantile (q in [0,1]) of recorded latencies or None if empty."""
|
||||
if not self._values:
|
||||
return 0.0
|
||||
q = float(q)
|
||||
if q <= 0.0:
|
||||
return min(self._values)
|
||||
if q >= 1.0:
|
||||
return self.max_latency
|
||||
vals = np.array(list(self._values), dtype=np.float32)
|
||||
return float(np.quantile(vals, q))
|
||||
|
||||
def p95(self) -> float | None:
|
||||
"""Return the 95th percentile latency or None if empty."""
|
||||
return self.percentile(0.95)
|
||||
@@ -0,0 +1,297 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Real-Time Chunking (RTC) implementation for LeRobot.
|
||||
|
||||
Based on Physical Intelligence's Kinetix implementation:
|
||||
https://github.com/Physical-Intelligence/real-time-chunking-kinetix/blob/main/src/model.py#L214
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.debug_tracker import Tracker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RTCProcessor:
|
||||
"""Real-Time Chunking processor for action chunking policies.
|
||||
|
||||
This class implements RTC techniques including velocity calculation,
|
||||
prefix attention, and adaptive chunk processing.
|
||||
"""
|
||||
|
||||
def __init__(self, rtc_config: RTCConfig):
|
||||
self.rtc_config = rtc_config
|
||||
|
||||
self.tracker = None
|
||||
|
||||
if rtc_config.debug:
|
||||
self.tracker = Tracker(
|
||||
enabled=rtc_config.debug,
|
||||
maxlen=rtc_config.debug_maxlen,
|
||||
)
|
||||
|
||||
# ====================== Tracker Proxy Methods ======================
|
||||
def track(
|
||||
self,
|
||||
time: float | Tensor,
|
||||
x_t: Tensor | None = None,
|
||||
v_t: Tensor | None = None,
|
||||
x1_t: Tensor | None = None,
|
||||
correction: Tensor | None = None,
|
||||
err: Tensor | None = None,
|
||||
weights: Tensor | None = None,
|
||||
guidance_weight: float | Tensor | None = None,
|
||||
inference_delay: int | None = None,
|
||||
execution_horizon: int | None = None,
|
||||
**metadata,
|
||||
) -> None:
|
||||
"""Proxy method to track debug information.
|
||||
|
||||
If tracker is None or disabled, this method does nothing.
|
||||
Otherwise, it forwards the call to tracker.track().
|
||||
"""
|
||||
if self.tracker is not None:
|
||||
self.tracker.track(
|
||||
time=time,
|
||||
x_t=x_t,
|
||||
v_t=v_t,
|
||||
x1_t=x1_t,
|
||||
correction=correction,
|
||||
err=err,
|
||||
weights=weights,
|
||||
guidance_weight=guidance_weight,
|
||||
inference_delay=inference_delay,
|
||||
execution_horizon=execution_horizon,
|
||||
**metadata,
|
||||
)
|
||||
|
||||
def get_all_debug_steps(self) -> list:
|
||||
"""Get all debug steps from tracker.
|
||||
|
||||
Returns empty list if tracker is disabled or None.
|
||||
"""
|
||||
if self.tracker is not None:
|
||||
return self.tracker.get_all_steps()
|
||||
return []
|
||||
|
||||
def is_debug_enabled(self) -> bool:
|
||||
"""Check if debug tracking is enabled.
|
||||
|
||||
Returns True if tracker exists and is enabled.
|
||||
"""
|
||||
return self.tracker is not None and self.tracker.enabled
|
||||
|
||||
def reset_tracker(self) -> None:
|
||||
"""Reset the tracker, clearing all recorded steps.
|
||||
|
||||
Does nothing if tracker is None.
|
||||
"""
|
||||
if self.tracker is not None:
|
||||
self.tracker.reset()
|
||||
|
||||
# ====================== End Tracker Proxy Methods ======================
|
||||
|
||||
def denoise_step(
|
||||
self,
|
||||
x_t,
|
||||
prev_chunk_left_over,
|
||||
inference_delay,
|
||||
time,
|
||||
original_denoise_step_partial,
|
||||
execution_horizon=None,
|
||||
) -> Tensor:
|
||||
"""RTC guidance wrapper around an existing denoiser.
|
||||
|
||||
This method wraps an original denoising callable that only takes ``x_t`` and
|
||||
returns a base denoised velocity ``v_t``. It then applies Real-Time Chunking
|
||||
(RTC) prefix guidance using the leftover prefix from the previous chunk.
|
||||
|
||||
Args:
|
||||
x_t (Tensor): Current latent/state to denoise. Shape ``(B, T, A)`` or ``(T, A)``.
|
||||
prev_chunk_left_over (Tensor | None): Unexecuted prefix from the previous
|
||||
chunk. Shape ``(B, T_prev, A)`` or ``(T_prev, A)``. If ``None``, no guidance
|
||||
is applied and the method returns ``v_t`` from the original denoiser.
|
||||
inference_delay (int): Number of timesteps from the prefix to use for guidance.
|
||||
time (float | Tensor): Scalar in [0, 1] indicating normalized time. Must be
|
||||
broadcastable with ``x_t``.
|
||||
original_denoise_step_partial (Callable[[Tensor], Tensor]): Callable that
|
||||
computes the base denoised velocity given only ``x_t``.
|
||||
execution_horizon (int | None): Horizon used to build prefix weights. If
|
||||
``None``, defaults to ``self.rtc_config.execution_horizon``.
|
||||
|
||||
Returns:
|
||||
Tensor: Guided velocity with the same shape as ``v_t``.
|
||||
|
||||
Notes:
|
||||
- If inputs are 2D, a batch dimension is temporarily added and removed at the end.
|
||||
- If ``prev_chunk_left_over`` is shorter than the current chunk length ``T``, it is
|
||||
right-padded with zeros to match ``T``.
|
||||
- Prefix weights are constructed via ``get_prefix_weights(inference_delay, execution_horizon, T)``
|
||||
and broadcast to ``(B, T, A)``.
|
||||
- Guidance correction is computed via autograd using ``x1_t = x_t + time * v_t`` and
|
||||
``error = (prev_chunk_left_over - x1_t) * weights``.
|
||||
- The final guidance weight is clamped by ``max_guidance_weight`` from the config.
|
||||
|
||||
Reference:
|
||||
https://www.physicalintelligence.company/download/real_time_chunking.pdf
|
||||
"""
|
||||
|
||||
# In the original implementation, the time goes from 0 to 1 and
|
||||
# In our implementation, the time goes from 1 to 0
|
||||
# So we need to invert the time
|
||||
tau = 1 - time
|
||||
|
||||
if prev_chunk_left_over is None:
|
||||
# First step, no guidance - return v_t
|
||||
v_t = original_denoise_step_partial(x_t)
|
||||
return v_t
|
||||
|
||||
x_t = x_t.clone().detach()
|
||||
|
||||
squeezed = False
|
||||
if len(x_t.shape) < 3:
|
||||
# Add batch dimension
|
||||
x_t = x_t.unsqueeze(0)
|
||||
squeezed = True
|
||||
|
||||
if len(prev_chunk_left_over.shape) < 3:
|
||||
# Add batch dimension
|
||||
prev_chunk_left_over = prev_chunk_left_over.unsqueeze(0)
|
||||
|
||||
if execution_horizon is None:
|
||||
execution_horizon = self.rtc_config.execution_horizon
|
||||
|
||||
# If the previous action chunk is to short then it doesn't make sense to use long execution horizon
|
||||
# because there is nothing to merge
|
||||
if execution_horizon > prev_chunk_left_over.shape[1]:
|
||||
execution_horizon = prev_chunk_left_over.shape[1]
|
||||
|
||||
batch_size = x_t.shape[0]
|
||||
action_chunk_size = x_t.shape[1]
|
||||
action_dim = x_t.shape[2]
|
||||
|
||||
if prev_chunk_left_over.shape[1] < action_chunk_size or prev_chunk_left_over.shape[2] < action_dim:
|
||||
padded = torch.zeros(batch_size, action_chunk_size, action_dim).to(x_t.device)
|
||||
padded[:, : prev_chunk_left_over.shape[1], : prev_chunk_left_over.shape[2]] = prev_chunk_left_over
|
||||
prev_chunk_left_over = padded
|
||||
|
||||
assert prev_chunk_left_over.shape == x_t.shape, (
|
||||
"The padded previous chunk must be the same size as the input tensor"
|
||||
)
|
||||
|
||||
weights = (
|
||||
self.get_prefix_weights(inference_delay, execution_horizon, action_chunk_size)
|
||||
.to(x_t.device)
|
||||
.unsqueeze(0)
|
||||
.unsqueeze(-1)
|
||||
)
|
||||
|
||||
with torch.enable_grad():
|
||||
v_t = original_denoise_step_partial(x_t)
|
||||
x_t.requires_grad_(True)
|
||||
|
||||
x1_t = x_t - time * v_t # noqa: N806
|
||||
err = (prev_chunk_left_over - x1_t) * weights
|
||||
grad_outputs = err.clone().detach()
|
||||
correction = torch.autograd.grad(x1_t, x_t, grad_outputs, retain_graph=False)[0]
|
||||
|
||||
max_guidance_weight = torch.as_tensor(self.rtc_config.max_guidance_weight)
|
||||
tau_tensor = torch.as_tensor(tau)
|
||||
squared_one_minus_tau = (1 - tau_tensor) ** 2
|
||||
inv_r2 = (squared_one_minus_tau + tau_tensor**2) / (squared_one_minus_tau)
|
||||
c = torch.nan_to_num((1 - tau_tensor) / tau_tensor, posinf=max_guidance_weight)
|
||||
guidance_weight = torch.nan_to_num(c * inv_r2, posinf=max_guidance_weight)
|
||||
guidance_weight = torch.minimum(guidance_weight, max_guidance_weight)
|
||||
|
||||
result = v_t - guidance_weight * correction
|
||||
|
||||
# Remove the batch dimension if it was added
|
||||
if squeezed:
|
||||
result = result.squeeze(0)
|
||||
correction = correction.squeeze(0)
|
||||
x1_t = x1_t.squeeze(0)
|
||||
err = err.squeeze(0)
|
||||
|
||||
self.track(
|
||||
time=time,
|
||||
x1_t=x1_t,
|
||||
correction=correction,
|
||||
err=err,
|
||||
weights=weights,
|
||||
guidance_weight=guidance_weight,
|
||||
inference_delay=inference_delay,
|
||||
execution_horizon=execution_horizon,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def get_prefix_weights(self, start, end, total):
|
||||
start = min(start, end)
|
||||
|
||||
if self.rtc_config.prefix_attention_schedule == RTCAttentionSchedule.ZEROS:
|
||||
weights = torch.zeros(total)
|
||||
weights[:start] = 1.0
|
||||
elif self.rtc_config.prefix_attention_schedule == RTCAttentionSchedule.ONES:
|
||||
weights = torch.ones(total)
|
||||
weights[end:] = 0.0
|
||||
elif self.rtc_config.prefix_attention_schedule == RTCAttentionSchedule.LINEAR:
|
||||
lin_weights = self._linweights(start, end, total)
|
||||
weights = self._add_trailing_zeros(lin_weights, total, end)
|
||||
weights = self._add_leading_ones(weights, start, total)
|
||||
elif self.rtc_config.prefix_attention_schedule == RTCAttentionSchedule.EXP:
|
||||
lin_weights = self._linweights(start, end, total)
|
||||
lin_weights = lin_weights * torch.expm1(lin_weights).div(math.e - 1)
|
||||
weights = self._add_trailing_zeros(lin_weights, total, end)
|
||||
weights = self._add_leading_ones(weights, start, total)
|
||||
|
||||
return weights
|
||||
|
||||
def _linweights(self, start, end, total):
|
||||
skip_steps_at_end = max(total - end, 0)
|
||||
|
||||
linspace_steps = total - skip_steps_at_end - start
|
||||
|
||||
if end <= start or linspace_steps <= 0:
|
||||
return torch.tensor([])
|
||||
|
||||
return torch.linspace(1, 0, linspace_steps + 2)[1:-1]
|
||||
|
||||
def _add_trailing_zeros(self, weights, total, end):
|
||||
zeros_len = total - end
|
||||
|
||||
if zeros_len <= 0:
|
||||
return weights
|
||||
|
||||
zeros = torch.zeros(zeros_len)
|
||||
return torch.cat([weights, zeros])
|
||||
|
||||
def _add_leading_ones(self, weights, start, total):
|
||||
ones_len = min(start, total)
|
||||
|
||||
if ones_len <= 0:
|
||||
return weights
|
||||
|
||||
ones = torch.ones(ones_len)
|
||||
return torch.cat([ones, weights])
|
||||
@@ -20,6 +20,7 @@ from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import (
|
||||
CosineDecayWithWarmupSchedulerConfig,
|
||||
)
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
|
||||
|
||||
@@ -102,6 +103,9 @@ class SmolVLAConfig(PreTrainedConfig):
|
||||
min_period: float = 4e-3 # sensitivity range for the timestep used in sine-cosine positional encoding
|
||||
max_period: float = 4.0
|
||||
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
rtc_config: RTCConfig | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@@ -54,12 +54,15 @@ policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base")
|
||||
|
||||
import math
|
||||
from collections import deque
|
||||
from typing import TypedDict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel
|
||||
from lerobot.policies.utils import (
|
||||
@@ -69,6 +72,12 @@ from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LAN
|
||||
from lerobot.utils.utils import get_safe_dtype
|
||||
|
||||
|
||||
class ActionSelectKwargs(TypedDict, total=False):
|
||||
inference_delay: int | None
|
||||
prev_chunk_left_over: Tensor | None
|
||||
execution_horizon: int | None
|
||||
|
||||
|
||||
def create_sinusoidal_pos_embedding(
|
||||
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
||||
) -> Tensor:
|
||||
@@ -232,8 +241,8 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
self.model = VLAFlowMatching(config)
|
||||
self.init_rtc_processor()
|
||||
self.model = VLAFlowMatching(config, rtc_processor=self.rtc_processor)
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
@@ -242,10 +251,28 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
ACTION: deque(maxlen=self.config.n_action_steps),
|
||||
}
|
||||
|
||||
def init_rtc_processor(self):
|
||||
"""Initialize RTC processor if RTC is enabled in config."""
|
||||
self.rtc_processor = None
|
||||
|
||||
# Lets create processor if the config provided
|
||||
# If RTC is not enabled - we still can track the denoising data
|
||||
if self.config.rtc_config is not None:
|
||||
self.rtc_processor = RTCProcessor(self.config.rtc_config)
|
||||
|
||||
# In case of calling init_rtc_processor after the model is created
|
||||
# We need to set the rtc_processor to the model
|
||||
# During the normal initialization process the model is not created yet
|
||||
model_value = getattr(self, "model", None)
|
||||
if model_value is not None:
|
||||
model_value.rtc_processor = self.rtc_processor
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
def _get_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
def _get_action_chunk(
|
||||
self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs: Unpack[ActionSelectKwargs]
|
||||
) -> Tensor:
|
||||
# TODO: Check if this for loop is needed.
|
||||
# Context: In fact, self.queues contains only ACTION field, and in inference, we don't have action in the batch
|
||||
# In the case of offline inference, we have the action in the batch
|
||||
@@ -260,7 +287,9 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
lang_tokens = batch[f"{OBS_LANGUAGE_TOKENS}"]
|
||||
lang_masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
||||
|
||||
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, noise=noise)
|
||||
actions = self.model.sample_actions(
|
||||
images, img_masks, lang_tokens, lang_masks, state, noise=noise, **kwargs
|
||||
)
|
||||
|
||||
# Unpad actions
|
||||
original_action_dim = self.config.action_feature.shape[0]
|
||||
@@ -278,30 +307,37 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
return batch
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
def predict_action_chunk(
|
||||
self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs: Unpack[ActionSelectKwargs]
|
||||
) -> Tensor:
|
||||
self.eval()
|
||||
|
||||
batch = self._prepare_batch(batch)
|
||||
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
|
||||
|
||||
actions = self._get_action_chunk(batch, noise)
|
||||
actions = self._get_action_chunk(batch, noise, **kwargs)
|
||||
return actions
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
def select_action(
|
||||
self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs: Unpack[ActionSelectKwargs]
|
||||
) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
|
||||
This method wraps `select_actions` in order to return one action at a time for execution in the
|
||||
environment. It works by managing the actions in a queue and only calling `select_actions` when the
|
||||
queue is empty.
|
||||
"""
|
||||
|
||||
assert not self._rtc_enabled(), (
|
||||
"RTC is not supported for select_action, use it with predict_action_chunk"
|
||||
)
|
||||
|
||||
self.eval()
|
||||
batch = self._prepare_batch(batch)
|
||||
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
|
||||
|
||||
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
|
||||
# querying the policy.
|
||||
if len(self._queues[ACTION]) == 0:
|
||||
if self._check_get_actions_condition():
|
||||
actions = self._get_action_chunk(batch, noise)
|
||||
|
||||
# `self.predict_action_chunk` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
|
||||
@@ -310,6 +346,12 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
|
||||
return self._queues[ACTION].popleft()
|
||||
|
||||
def _check_get_actions_condition(self) -> bool:
|
||||
return len(self._queues[ACTION]) == 0
|
||||
|
||||
def _rtc_enabled(self) -> bool:
|
||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||
|
||||
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]:
|
||||
"""Do a full training forward pass to compute the loss"""
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
@@ -471,7 +513,7 @@ class VLAFlowMatching(nn.Module):
|
||||
└──────────────────────────────┘
|
||||
"""
|
||||
|
||||
def __init__(self, config: SmolVLAConfig):
|
||||
def __init__(self, config: SmolVLAConfig, rtc_processor: RTCProcessor | None = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
@@ -485,7 +527,6 @@ class VLAFlowMatching(nn.Module):
|
||||
num_vlm_layers=self.config.num_vlm_layers,
|
||||
self_attn_every_n_layers=self.config.self_attn_every_n_layers,
|
||||
expert_width_multiplier=self.config.expert_width_multiplier,
|
||||
device=self.config.device,
|
||||
)
|
||||
self.state_proj = nn.Linear(
|
||||
self.config.max_state_dim, self.vlm_with_expert.config.text_config.hidden_size
|
||||
@@ -510,6 +551,10 @@ class VLAFlowMatching(nn.Module):
|
||||
self.add_image_special_tokens = self.config.add_image_special_tokens
|
||||
self.image_end_token = torch.tensor([self.fake_image_token], dtype=torch.long)
|
||||
self.prefix_length = self.config.prefix_length
|
||||
self.rtc_processor = rtc_processor
|
||||
|
||||
def _rtc_enabled(self):
|
||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||
|
||||
def set_requires_grad(self):
|
||||
for params in self.state_proj.parameters():
|
||||
@@ -706,7 +751,16 @@ class VLAFlowMatching(nn.Module):
|
||||
losses = F.mse_loss(u_t, v_t, reduction="none")
|
||||
return losses
|
||||
|
||||
def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor:
|
||||
def sample_actions(
|
||||
self,
|
||||
images,
|
||||
img_masks,
|
||||
lang_tokens,
|
||||
lang_masks,
|
||||
state,
|
||||
noise=None,
|
||||
**kwargs: Unpack[ActionSelectKwargs],
|
||||
) -> Tensor:
|
||||
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
|
||||
bsize = state.shape[0]
|
||||
device = state.device
|
||||
@@ -734,17 +788,45 @@ class VLAFlowMatching(nn.Module):
|
||||
|
||||
x_t = noise
|
||||
time = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
|
||||
while time >= -dt / 2:
|
||||
expanded_time = time.expand(bsize)
|
||||
v_t = self.denoise_step(
|
||||
prefix_pad_masks,
|
||||
past_key_values,
|
||||
x_t,
|
||||
expanded_time,
|
||||
)
|
||||
|
||||
# Define a closure function to properly capture expanded_time
|
||||
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
|
||||
return self.denoise_step(
|
||||
x_t=input_x_t,
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
timestep=current_timestep,
|
||||
)
|
||||
|
||||
if self._rtc_enabled():
|
||||
inference_delay = kwargs.get("inference_delay")
|
||||
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
||||
execution_horizon = kwargs.get("execution_horizon")
|
||||
|
||||
v_t = self.rtc_processor.denoise_step(
|
||||
x_t=x_t,
|
||||
prev_chunk_left_over=prev_chunk_left_over,
|
||||
inference_delay=inference_delay,
|
||||
time=time,
|
||||
original_denoise_step_partial=denoise_step_partial_call,
|
||||
execution_horizon=execution_horizon,
|
||||
)
|
||||
else:
|
||||
v_t = denoise_step_partial_call(x_t)
|
||||
|
||||
# Euler step
|
||||
x_t += dt * v_t
|
||||
|
||||
# Record x_t and v_t after Euler step (other params are recorded in rtc_processor.denoise_step)
|
||||
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
|
||||
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
|
||||
|
||||
time += dt
|
||||
|
||||
return x_t
|
||||
|
||||
def denoise_step(
|
||||
|
||||
@@ -0,0 +1,206 @@
|
||||
"""
|
||||
Profiling utilities for performance analysis.
|
||||
|
||||
Usage:
|
||||
from lerobot.utils.profiling import profile_method, get_profiling_stats, print_profiling_summary
|
||||
|
||||
@profile_method
|
||||
def my_slow_function(x):
|
||||
return x * 2
|
||||
|
||||
# At end of execution:
|
||||
print_profiling_summary()
|
||||
"""
|
||||
|
||||
import functools
|
||||
import logging
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from threading import Lock
|
||||
from typing import Any, Callable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global profiling statistics storage
|
||||
_profiling_stats: dict[str, list[float]] = defaultdict(list)
|
||||
_profiling_lock = Lock()
|
||||
_profiling_enabled = False
|
||||
|
||||
|
||||
def enable_profiling():
|
||||
"""Enable profiling globally."""
|
||||
global _profiling_enabled
|
||||
_profiling_enabled = True
|
||||
logger.info("Profiling enabled")
|
||||
|
||||
|
||||
def disable_profiling():
|
||||
"""Disable profiling globally."""
|
||||
global _profiling_enabled
|
||||
_profiling_enabled = False
|
||||
logger.info("Profiling disabled")
|
||||
|
||||
|
||||
def is_profiling_enabled() -> bool:
|
||||
"""Check if profiling is enabled."""
|
||||
return _profiling_enabled
|
||||
|
||||
|
||||
def record_timing(name: str, duration: float):
|
||||
"""Record a timing measurement.
|
||||
|
||||
Args:
|
||||
name: Name/identifier for this timing
|
||||
duration: Duration in seconds
|
||||
"""
|
||||
if not _profiling_enabled:
|
||||
return
|
||||
|
||||
with _profiling_lock:
|
||||
_profiling_stats[name].append(duration)
|
||||
|
||||
|
||||
def profile_method(func: Callable) -> Callable:
|
||||
"""Decorator to profile a method or function.
|
||||
|
||||
Args:
|
||||
func: Function to profile
|
||||
|
||||
Returns:
|
||||
Wrapped function that records execution time
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs) -> Any:
|
||||
if not _profiling_enabled:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
start = time.perf_counter()
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
return result
|
||||
finally:
|
||||
duration = time.perf_counter() - start
|
||||
# Use fully qualified name
|
||||
name = f"{func.__module__}.{func.__qualname__}"
|
||||
record_timing(name, duration)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class ProfileContext:
|
||||
"""Context manager for profiling code blocks.
|
||||
|
||||
Usage:
|
||||
with ProfileContext("my_operation"):
|
||||
# ... code to profile ...
|
||||
"""
|
||||
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self.start = None
|
||||
|
||||
def __enter__(self):
|
||||
if _profiling_enabled:
|
||||
self.start = time.perf_counter()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
if _profiling_enabled and self.start is not None:
|
||||
duration = time.perf_counter() - self.start
|
||||
record_timing(self.name, duration)
|
||||
|
||||
|
||||
def get_profiling_stats() -> dict[str, dict[str, float]]:
|
||||
"""Get summary statistics for all profiled functions.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping function names to their stats (count, mean, min, max, total)
|
||||
"""
|
||||
with _profiling_lock:
|
||||
summary = {}
|
||||
for name, times in _profiling_stats.items():
|
||||
if times:
|
||||
summary[name] = {
|
||||
"count": len(times),
|
||||
"mean": sum(times) / len(times),
|
||||
"min": min(times),
|
||||
"max": max(times),
|
||||
"total": sum(times),
|
||||
"mean_ms": (sum(times) / len(times)) * 1000,
|
||||
"min_ms": min(times) * 1000,
|
||||
"max_ms": max(times) * 1000,
|
||||
}
|
||||
return summary
|
||||
|
||||
|
||||
def clear_profiling_stats():
|
||||
"""Clear all profiling statistics."""
|
||||
with _profiling_lock:
|
||||
_profiling_stats.clear()
|
||||
logger.info("Profiling stats cleared")
|
||||
|
||||
|
||||
def print_profiling_summary(sort_by: str = "total"):
|
||||
"""Print formatted summary of profiling statistics.
|
||||
|
||||
Args:
|
||||
sort_by: Sort key ('total', 'mean', 'count', 'max')
|
||||
"""
|
||||
summary = get_profiling_stats()
|
||||
|
||||
if not summary:
|
||||
logger.info("No profiling data available")
|
||||
return
|
||||
|
||||
logger.info("\n" + "=" * 100)
|
||||
logger.info("PROFILING SUMMARY")
|
||||
logger.info("=" * 100)
|
||||
|
||||
# Sort by requested key
|
||||
sorted_items = sorted(summary.items(), key=lambda x: x[1].get(sort_by, 0), reverse=True)
|
||||
|
||||
# Print header
|
||||
logger.info(
|
||||
f"{'Function':<60} {'Count':>8} {'Mean (ms)':>12} {'Min (ms)':>12} {'Max (ms)':>12} {'Total (s)':>12}"
|
||||
)
|
||||
logger.info("-" * 100)
|
||||
|
||||
# Print each function's stats
|
||||
for name, stats in sorted_items:
|
||||
# Shorten long names
|
||||
display_name = name if len(name) <= 60 else "..." + name[-57:]
|
||||
|
||||
logger.info(
|
||||
f"{display_name:<60} "
|
||||
f"{stats['count']:>8} "
|
||||
f"{stats['mean_ms']:>12.2f} "
|
||||
f"{stats['min_ms']:>12.2f} "
|
||||
f"{stats['max_ms']:>12.2f} "
|
||||
f"{stats['total']:>12.2f}"
|
||||
)
|
||||
|
||||
logger.info("=" * 100)
|
||||
|
||||
# Print summary
|
||||
total_time = sum(s["total"] for s in summary.values())
|
||||
total_calls = sum(s["count"] for s in summary.values())
|
||||
logger.info(f"\nTotal profiled time: {total_time:.2f}s across {total_calls} calls")
|
||||
logger.info("=" * 100 + "\n")
|
||||
|
||||
|
||||
def profile_section(name: str):
|
||||
"""Return a context manager for profiling a code section.
|
||||
|
||||
Args:
|
||||
name: Name for this section
|
||||
|
||||
Returns:
|
||||
ProfileContext instance
|
||||
|
||||
Usage:
|
||||
with profile_section("data_loading"):
|
||||
data = load_data()
|
||||
"""
|
||||
return ProfileContext(name)
|
||||
|
||||
@@ -0,0 +1,336 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test PI0.5 policy with Real-Time Chunking (RTC) enabled during inference."""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
# Skip this entire module in CI
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="This test requires local OpenPI installation and is not meant for CI",
|
||||
)
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature, RTCAttentionSchedule # noqa: E402
|
||||
from lerobot.policies.pi05 import PI05Config, PI05Policy, make_pi05_pre_post_processors # noqa: E402
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig # noqa: E402
|
||||
from lerobot.utils.random_utils import set_seed # noqa: E402
|
||||
from tests.utils import require_cuda # noqa: E402
|
||||
|
||||
|
||||
@require_cuda
|
||||
def test_pi05_rtc_initialization():
|
||||
"""Test PI0.5 policy can initialize RTC processor."""
|
||||
set_seed(42)
|
||||
|
||||
config = PI05Config(max_action_dim=7, max_state_dim=14, dtype="float32")
|
||||
|
||||
# Add RTC config
|
||||
config.rtc_config = RTCConfig(
|
||||
enabled=True,
|
||||
execution_horizon=10,
|
||||
max_guidance_weight=5.0,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||
debug=False,
|
||||
)
|
||||
|
||||
config.input_features = {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
|
||||
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||
}
|
||||
|
||||
# Instantiate policy
|
||||
policy = PI05Policy(config)
|
||||
|
||||
# Verify RTC processor is initialized
|
||||
assert hasattr(policy, "rtc_processor")
|
||||
assert policy.rtc_processor is not None
|
||||
assert policy.rtc_processor.rtc_config.enabled is True
|
||||
|
||||
print("✓ PI0.5 RTC initialization: Test passed")
|
||||
|
||||
|
||||
@require_cuda
|
||||
def test_pi05_rtc_initialization_without_rtc_config():
|
||||
"""Test PI0.5 policy can initialize without RTC config."""
|
||||
set_seed(42)
|
||||
|
||||
config = PI05Config(max_action_dim=7, max_state_dim=14, dtype="float32")
|
||||
|
||||
# Instantiate policy
|
||||
policy = PI05Policy(config)
|
||||
|
||||
# Verify RTC processor is not initialized
|
||||
assert hasattr(policy, "rtc_processor")
|
||||
assert policy.rtc_processor is None
|
||||
assert policy.model.rtc_processor is None
|
||||
assert policy._rtc_enabled() is False
|
||||
|
||||
print("✓ PI0.5 RTC initialization without RTC config: Test passed")
|
||||
|
||||
|
||||
@require_cuda
|
||||
def test_pi05_rtc_inference_with_prev_chunk():
|
||||
"""Test PI0.5 policy inference with RTC and previous chunk."""
|
||||
set_seed(42)
|
||||
|
||||
config = PI05Config(max_action_dim=7, max_state_dim=14, chunk_size=50, dtype="float32")
|
||||
|
||||
# Add RTC config
|
||||
config.rtc_config = RTCConfig(
|
||||
enabled=True,
|
||||
execution_horizon=10,
|
||||
max_guidance_weight=5.0,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||
debug=False,
|
||||
)
|
||||
|
||||
config.input_features = {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
|
||||
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||
}
|
||||
|
||||
# Create dataset stats (PI0.5 uses QUANTILES normalization)
|
||||
dataset_stats = {
|
||||
"observation.state": {
|
||||
"mean": torch.zeros(14),
|
||||
"std": torch.ones(14),
|
||||
"q01": -torch.ones(14),
|
||||
"q99": torch.ones(14),
|
||||
},
|
||||
"action": {
|
||||
"mean": torch.zeros(7),
|
||||
"std": torch.ones(7),
|
||||
"q01": -torch.ones(7),
|
||||
"q99": torch.ones(7),
|
||||
},
|
||||
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
|
||||
}
|
||||
|
||||
# Instantiate policy and preprocessor
|
||||
policy = PI05Policy(config)
|
||||
policy.eval()
|
||||
preprocessor, _ = make_pi05_pre_post_processors(config=config, dataset_stats=dataset_stats)
|
||||
|
||||
device = config.device
|
||||
|
||||
# Create dummy batch
|
||||
batch = {
|
||||
"observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
|
||||
"observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
|
||||
"task": ["Pick up the object"],
|
||||
}
|
||||
batch = preprocessor(batch)
|
||||
|
||||
# Create previous chunk
|
||||
prev_chunk = torch.randn(1, 25, 7, dtype=torch.float32, device=device)
|
||||
|
||||
with torch.no_grad():
|
||||
# Use same noise for fair comparison
|
||||
noise = policy.model.sample_noise((1, config.chunk_size, 7), device)
|
||||
|
||||
# Test with RTC and previous chunk
|
||||
actions_with_rtc = policy.predict_action_chunk(
|
||||
batch,
|
||||
noise=noise.clone(),
|
||||
prev_chunk_left_over=prev_chunk,
|
||||
inference_delay=4,
|
||||
execution_horizon=10,
|
||||
)
|
||||
|
||||
# Test without RTC for comparison
|
||||
policy.config.rtc_config.enabled = False
|
||||
actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone())
|
||||
policy.config.rtc_config.enabled = True
|
||||
|
||||
# Verify shapes
|
||||
assert actions_with_rtc.shape == (1, config.chunk_size, 7)
|
||||
assert actions_without_rtc.shape == (1, config.chunk_size, 7)
|
||||
|
||||
# With previous chunk, actions should be different (RTC guidance applied)
|
||||
assert not torch.allclose(actions_with_rtc, actions_without_rtc, rtol=1e-3)
|
||||
|
||||
print("✓ PI0.5 RTC inference with prev_chunk: Test passed")
|
||||
|
||||
|
||||
@require_cuda
|
||||
def test_pi05_rtc_inference_without_prev_chunk():
|
||||
"""Test PI0.5 policy inference with RTC but no previous chunk (RTC should have no effect)."""
|
||||
set_seed(42)
|
||||
|
||||
config = PI05Config(max_action_dim=7, max_state_dim=14, chunk_size=50, dtype="float32")
|
||||
|
||||
# Add RTC config
|
||||
config.rtc_config = RTCConfig(
|
||||
enabled=True,
|
||||
execution_horizon=10,
|
||||
max_guidance_weight=5.0,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||
debug=False,
|
||||
)
|
||||
|
||||
config.input_features = {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
|
||||
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||
}
|
||||
|
||||
# Create dataset stats (PI0.5 uses QUANTILES normalization)
|
||||
dataset_stats = {
|
||||
"observation.state": {
|
||||
"mean": torch.zeros(14),
|
||||
"std": torch.ones(14),
|
||||
"q01": -torch.ones(14),
|
||||
"q99": torch.ones(14),
|
||||
},
|
||||
"action": {
|
||||
"mean": torch.zeros(7),
|
||||
"std": torch.ones(7),
|
||||
"q01": -torch.ones(7),
|
||||
"q99": torch.ones(7),
|
||||
},
|
||||
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
|
||||
}
|
||||
|
||||
# Instantiate policy and preprocessor
|
||||
policy = PI05Policy(config)
|
||||
policy.eval()
|
||||
preprocessor, _ = make_pi05_pre_post_processors(config=config, dataset_stats=dataset_stats)
|
||||
|
||||
device = config.device
|
||||
|
||||
# Create dummy batch
|
||||
batch = {
|
||||
"observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
|
||||
"observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
|
||||
"task": ["Pick up the object"],
|
||||
}
|
||||
batch = preprocessor(batch)
|
||||
|
||||
with torch.no_grad():
|
||||
# Use same noise for fair comparison
|
||||
noise = policy.model.sample_noise((1, config.chunk_size, 7), device)
|
||||
|
||||
# Test with RTC enabled but no previous chunk
|
||||
actions_with_rtc_no_prev = policy.predict_action_chunk(
|
||||
batch,
|
||||
noise=noise.clone(),
|
||||
prev_chunk_left_over=None,
|
||||
)
|
||||
|
||||
# Test without RTC
|
||||
policy.config.rtc_config.enabled = False
|
||||
actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone())
|
||||
policy.config.rtc_config.enabled = True
|
||||
|
||||
# Without previous chunk, RTC should have no effect
|
||||
assert torch.allclose(actions_with_rtc_no_prev, actions_without_rtc, rtol=1e-5)
|
||||
|
||||
print("✓ PI0.5 RTC inference without prev_chunk: Test passed")
|
||||
|
||||
|
||||
@require_cuda
|
||||
def test_pi05_rtc_validation_rules():
|
||||
"""Test PI0.5 policy with RTC follows all three validation rules."""
|
||||
set_seed(42)
|
||||
|
||||
config = PI05Config(max_action_dim=7, max_state_dim=14, chunk_size=50, dtype="float32")
|
||||
|
||||
# Add RTC config
|
||||
config.rtc_config = RTCConfig(
|
||||
enabled=True,
|
||||
execution_horizon=10,
|
||||
max_guidance_weight=5.0,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||
debug=False,
|
||||
)
|
||||
|
||||
config.input_features = {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
|
||||
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||
}
|
||||
|
||||
# Create dataset stats (PI0.5 uses QUANTILES normalization)
|
||||
dataset_stats = {
|
||||
"observation.state": {
|
||||
"mean": torch.zeros(14),
|
||||
"std": torch.ones(14),
|
||||
"q01": -torch.ones(14),
|
||||
"q99": torch.ones(14),
|
||||
},
|
||||
"action": {
|
||||
"mean": torch.zeros(7),
|
||||
"std": torch.ones(7),
|
||||
"q01": -torch.ones(7),
|
||||
"q99": torch.ones(7),
|
||||
},
|
||||
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
|
||||
}
|
||||
|
||||
# Instantiate policy and preprocessor
|
||||
policy = PI05Policy(config)
|
||||
policy.eval()
|
||||
preprocessor, _ = make_pi05_pre_post_processors(config=config, dataset_stats=dataset_stats)
|
||||
|
||||
device = config.device
|
||||
|
||||
# Create dummy batch
|
||||
batch = {
|
||||
"observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
|
||||
"observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
|
||||
"task": ["Pick up the object"],
|
||||
}
|
||||
batch = preprocessor(batch)
|
||||
|
||||
# Create previous chunk
|
||||
prev_chunk = torch.randn(1, 25, 7, dtype=torch.float32, device=device)
|
||||
|
||||
inference_delay = 4
|
||||
execution_horizon = 10
|
||||
|
||||
with torch.no_grad():
|
||||
# Use same noise for fair comparison
|
||||
noise = policy.model.sample_noise((1, config.chunk_size, 7), device)
|
||||
|
||||
# Test with RTC
|
||||
actions_with_rtc = policy.predict_action_chunk(
|
||||
batch,
|
||||
noise=noise.clone(),
|
||||
prev_chunk_left_over=prev_chunk,
|
||||
inference_delay=inference_delay,
|
||||
execution_horizon=execution_horizon,
|
||||
)
|
||||
|
||||
# Test without RTC
|
||||
policy.config.rtc_config.enabled = False
|
||||
actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone())
|
||||
policy.config.rtc_config.enabled = True
|
||||
|
||||
assert not torch.allclose(actions_with_rtc, actions_without_rtc, rtol=1e-3)
|
||||
@@ -0,0 +1,378 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test PI0 policy with Real-Time Chunking (RTC) enabled during inference."""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
# Skip this entire module in CI
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="This test requires local OpenPI installation and is not meant for CI",
|
||||
)
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature, RTCAttentionSchedule # noqa: E402
|
||||
from lerobot.policies.pi0 import PI0Config, PI0Policy, make_pi0_pre_post_processors # noqa: E402
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig # noqa: E402
|
||||
from lerobot.utils.random_utils import set_seed # noqa: E402
|
||||
from tests.utils import require_cuda # noqa: E402
|
||||
|
||||
|
||||
@require_cuda
|
||||
def test_pi0_rtc_initialization():
|
||||
"""Test PI0 policy can initialize RTC processor."""
|
||||
set_seed(42)
|
||||
|
||||
config = PI0Config(max_action_dim=7, max_state_dim=14, dtype="float32")
|
||||
|
||||
# Add RTC config
|
||||
config.rtc_config = RTCConfig(
|
||||
enabled=True,
|
||||
execution_horizon=10,
|
||||
max_guidance_weight=5.0,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||
debug=False,
|
||||
)
|
||||
|
||||
config.input_features = {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
|
||||
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||
}
|
||||
|
||||
# Instantiate policy
|
||||
policy = PI0Policy(config)
|
||||
|
||||
# Verify RTC processor is initialized
|
||||
assert hasattr(policy, "rtc_processor")
|
||||
assert policy.rtc_processor is not None
|
||||
assert policy.rtc_processor.rtc_config.enabled is True
|
||||
|
||||
print("✓ PI0 RTC initialization: Test passed")
|
||||
|
||||
|
||||
@require_cuda
|
||||
def test_pi0_rtc_initialization_without_rtc_config():
|
||||
"""Test PI0 policy can initialize without RTC config."""
|
||||
set_seed(42)
|
||||
|
||||
config = PI0Config(max_action_dim=7, max_state_dim=14, dtype="float32")
|
||||
|
||||
# Instantiate policy
|
||||
policy = PI0Policy(config)
|
||||
|
||||
# Verify RTC processor is not initialized
|
||||
assert hasattr(policy, "rtc_processor")
|
||||
assert policy.rtc_processor is None
|
||||
assert policy.model.rtc_processor is None
|
||||
assert policy._rtc_enabled() is False
|
||||
|
||||
print("✓ PI0 RTC initialization without RTC config: Test passed")
|
||||
|
||||
|
||||
def test_pi0_rtc_inference_with_prev_chunk():
|
||||
"""Test PI0 policy inference with RTC and previous chunk."""
|
||||
set_seed(42)
|
||||
|
||||
config = PI0Config(max_action_dim=7, max_state_dim=14, chunk_size=50, dtype="float32")
|
||||
|
||||
# Add RTC config
|
||||
config.rtc_config = RTCConfig(
|
||||
enabled=True,
|
||||
execution_horizon=10,
|
||||
max_guidance_weight=5.0,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||
debug=False,
|
||||
)
|
||||
|
||||
config.input_features = {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
|
||||
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||
}
|
||||
|
||||
# Create dataset stats
|
||||
dataset_stats = {
|
||||
"observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
|
||||
"action": {"mean": torch.zeros(7), "std": torch.ones(7)},
|
||||
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
|
||||
}
|
||||
|
||||
# Instantiate policy and preprocessor
|
||||
policy = PI0Policy(config)
|
||||
policy.eval()
|
||||
preprocessor, _ = make_pi0_pre_post_processors(config=config, dataset_stats=dataset_stats)
|
||||
|
||||
device = config.device
|
||||
|
||||
# Create dummy batch
|
||||
batch = {
|
||||
"observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
|
||||
"observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
|
||||
"task": ["Pick up the object"],
|
||||
}
|
||||
batch = preprocessor(batch)
|
||||
|
||||
# Create previous chunk
|
||||
prev_chunk = torch.randn(1, 25, 7, dtype=torch.float32, device=device)
|
||||
|
||||
with torch.no_grad():
|
||||
# Use same noise for fair comparison
|
||||
noise = policy.model.sample_noise((1, config.chunk_size, 7), device)
|
||||
|
||||
# Test with RTC and previous chunk
|
||||
actions_with_rtc = policy.predict_action_chunk(
|
||||
batch,
|
||||
noise=noise.clone(),
|
||||
prev_chunk_left_over=prev_chunk,
|
||||
inference_delay=4,
|
||||
execution_horizon=10,
|
||||
)
|
||||
|
||||
# Test without RTC for comparison
|
||||
policy.config.rtc_config.enabled = False
|
||||
actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone())
|
||||
policy.config.rtc_config.enabled = True
|
||||
|
||||
# Verify shapes
|
||||
assert actions_with_rtc.shape == (1, config.chunk_size, 7)
|
||||
assert actions_without_rtc.shape == (1, config.chunk_size, 7)
|
||||
|
||||
# With previous chunk, actions should be different (RTC guidance applied)
|
||||
assert not torch.allclose(actions_with_rtc, actions_without_rtc, rtol=1e-3)
|
||||
|
||||
print("✓ PI0 RTC inference with prev_chunk: Test passed")
|
||||
|
||||
|
||||
@require_cuda
|
||||
def test_pi0_rtc_inference_without_prev_chunk():
|
||||
"""Test PI0 policy inference with RTC but no previous chunk (RTC should have no effect)."""
|
||||
set_seed(42)
|
||||
|
||||
config = PI0Config(max_action_dim=7, max_state_dim=14, chunk_size=50, dtype="float32")
|
||||
|
||||
# Add RTC config
|
||||
config.rtc_config = RTCConfig(
|
||||
enabled=True,
|
||||
execution_horizon=10,
|
||||
max_guidance_weight=5.0,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||
debug=False,
|
||||
)
|
||||
|
||||
config.input_features = {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
|
||||
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||
}
|
||||
|
||||
# Create dataset stats
|
||||
dataset_stats = {
|
||||
"observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
|
||||
"action": {"mean": torch.zeros(7), "std": torch.ones(7)},
|
||||
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
|
||||
}
|
||||
|
||||
# Instantiate policy and preprocessor
|
||||
policy = PI0Policy(config)
|
||||
policy.eval()
|
||||
preprocessor, _ = make_pi0_pre_post_processors(config=config, dataset_stats=dataset_stats)
|
||||
|
||||
device = config.device
|
||||
|
||||
# Create dummy batch
|
||||
batch = {
|
||||
"observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
|
||||
"observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
|
||||
"task": ["Pick up the object"],
|
||||
}
|
||||
batch = preprocessor(batch)
|
||||
|
||||
with torch.no_grad():
|
||||
# Use same noise for fair comparison
|
||||
noise = policy.model.sample_noise((1, config.chunk_size, 7), device)
|
||||
|
||||
# Test with RTC enabled but no previous chunk
|
||||
actions_with_rtc_no_prev = policy.predict_action_chunk(
|
||||
batch,
|
||||
noise=noise.clone(),
|
||||
prev_chunk_left_over=None,
|
||||
)
|
||||
|
||||
# Test without RTC
|
||||
policy.config.rtc_config.enabled = False
|
||||
actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone())
|
||||
policy.config.rtc_config.enabled = True
|
||||
|
||||
# Without previous chunk, RTC should have no effect
|
||||
assert torch.allclose(actions_with_rtc_no_prev, actions_without_rtc, rtol=1e-5)
|
||||
|
||||
print("✓ PI0 RTC inference without prev_chunk: Test passed")
|
||||
|
||||
|
||||
@require_cuda
|
||||
def test_pi0_rtc_validation_rules():
|
||||
"""Test PI0 policy with RTC follows all three validation rules."""
|
||||
set_seed(42)
|
||||
|
||||
config = PI0Config(max_action_dim=7, max_state_dim=14, chunk_size=50, dtype="float32")
|
||||
|
||||
# Add RTC config
|
||||
config.rtc_config = RTCConfig(
|
||||
enabled=True,
|
||||
execution_horizon=10,
|
||||
max_guidance_weight=5.0,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||
debug=False,
|
||||
)
|
||||
|
||||
config.input_features = {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
|
||||
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||
}
|
||||
|
||||
# Create dataset stats
|
||||
dataset_stats = {
|
||||
"observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
|
||||
"action": {"mean": torch.zeros(7), "std": torch.ones(7)},
|
||||
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
|
||||
}
|
||||
|
||||
# Instantiate policy and preprocessor
|
||||
policy = PI0Policy(config)
|
||||
policy.eval()
|
||||
preprocessor, _ = make_pi0_pre_post_processors(config=config, dataset_stats=dataset_stats)
|
||||
|
||||
device = config.device
|
||||
|
||||
# Create dummy batch
|
||||
batch = {
|
||||
"observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
|
||||
"observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
|
||||
"task": ["Pick up the object"],
|
||||
}
|
||||
batch = preprocessor(batch)
|
||||
|
||||
# Create previous chunk
|
||||
prev_chunk = torch.randn(1, 25, 7, dtype=torch.float32, device=device)
|
||||
|
||||
inference_delay = 4
|
||||
execution_horizon = 10
|
||||
|
||||
with torch.no_grad():
|
||||
# Use same noise for fair comparison
|
||||
noise = policy.model.sample_noise((1, config.chunk_size, 7), device)
|
||||
|
||||
# Test with RTC
|
||||
actions_with_rtc = policy.predict_action_chunk(
|
||||
batch,
|
||||
noise=noise.clone(),
|
||||
prev_chunk_left_over=prev_chunk,
|
||||
inference_delay=inference_delay,
|
||||
execution_horizon=execution_horizon,
|
||||
)
|
||||
|
||||
# Test without RTC
|
||||
policy.config.rtc_config.enabled = False
|
||||
actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone())
|
||||
policy.config.rtc_config.enabled = True
|
||||
|
||||
assert not torch.allclose(actions_with_rtc, actions_without_rtc, rtol=1e-3)
|
||||
|
||||
"""Test PI0 with different RTC attention schedules."""
|
||||
set_seed(42)
|
||||
|
||||
schedules = [
|
||||
RTCAttentionSchedule.ZEROS,
|
||||
RTCAttentionSchedule.ONES,
|
||||
RTCAttentionSchedule.LINEAR,
|
||||
RTCAttentionSchedule.EXP,
|
||||
]
|
||||
|
||||
config = PI0Config(max_action_dim=7, max_state_dim=14, chunk_size=50, dtype="float32")
|
||||
|
||||
config.input_features = {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
|
||||
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||
}
|
||||
|
||||
# Create dataset stats
|
||||
dataset_stats = {
|
||||
"observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
|
||||
"action": {"mean": torch.zeros(7), "std": torch.ones(7)},
|
||||
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
|
||||
}
|
||||
|
||||
device = config.device
|
||||
|
||||
for schedule in schedules:
|
||||
print(f"Testing schedule: {schedule}")
|
||||
|
||||
# Add RTC config with specific schedule
|
||||
config.rtc_config = RTCConfig(
|
||||
enabled=True,
|
||||
execution_horizon=10,
|
||||
max_guidance_weight=5.0,
|
||||
prefix_attention_schedule=schedule,
|
||||
debug=False,
|
||||
)
|
||||
|
||||
# Instantiate policy
|
||||
policy = PI0Policy(config)
|
||||
policy.eval()
|
||||
preprocessor, _ = make_pi0_pre_post_processors(config=config, dataset_stats=dataset_stats)
|
||||
|
||||
# Create dummy batch
|
||||
batch = {
|
||||
"observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
|
||||
"observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
|
||||
"task": ["Pick up the object"],
|
||||
}
|
||||
batch = preprocessor(batch)
|
||||
|
||||
# Create previous chunk
|
||||
prev_chunk = torch.randn(1, 25, 7, dtype=torch.float32, device=device)
|
||||
|
||||
with torch.no_grad():
|
||||
noise = policy.model.sample_noise((1, config.chunk_size, 7), device)
|
||||
actions = policy.predict_action_chunk(
|
||||
batch,
|
||||
noise=noise,
|
||||
prev_chunk_left_over=prev_chunk,
|
||||
inference_delay=4,
|
||||
execution_horizon=10,
|
||||
)
|
||||
|
||||
# Verify shape
|
||||
assert actions.shape == (1, config.chunk_size, 7)
|
||||
print(f" ✓ Schedule {schedule}: Test passed")
|
||||
|
||||
print("✓ PI0 RTC different schedules: All schedules tested")
|
||||
@@ -0,0 +1,825 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for RTC ActionQueue module."""
|
||||
|
||||
import threading
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.policies.rtc.action_queue import ActionQueue
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
|
||||
# ====================== Fixtures ======================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def rtc_config_enabled():
|
||||
"""Create an RTC config with RTC enabled."""
|
||||
return RTCConfig(enabled=True, execution_horizon=10, max_guidance_weight=1.0)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def rtc_config_disabled():
|
||||
"""Create an RTC config with RTC disabled."""
|
||||
return RTCConfig(enabled=False, execution_horizon=10, max_guidance_weight=1.0)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_actions():
|
||||
"""Create sample action tensors for testing."""
|
||||
return {
|
||||
"original": torch.randn(50, 6), # (time_steps, action_dim)
|
||||
"processed": torch.randn(50, 6),
|
||||
"short": torch.randn(10, 6),
|
||||
"longer": torch.randn(100, 6),
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def action_queue_rtc_enabled(rtc_config_enabled):
|
||||
"""Create an ActionQueue with RTC enabled."""
|
||||
return ActionQueue(rtc_config_enabled)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def action_queue_rtc_disabled(rtc_config_disabled):
|
||||
"""Create an ActionQueue with RTC disabled."""
|
||||
return ActionQueue(rtc_config_disabled)
|
||||
|
||||
|
||||
# ====================== Initialization Tests ======================
|
||||
|
||||
|
||||
def test_action_queue_initialization_rtc_enabled(rtc_config_enabled):
|
||||
"""Test ActionQueue initializes correctly with RTC enabled."""
|
||||
queue = ActionQueue(rtc_config_enabled)
|
||||
assert queue.queue is None
|
||||
assert queue.original_queue is None
|
||||
assert queue.last_index == 0
|
||||
assert queue.cfg.enabled is True
|
||||
|
||||
|
||||
def test_action_queue_initialization_rtc_disabled(rtc_config_disabled):
|
||||
"""Test ActionQueue initializes correctly with RTC disabled."""
|
||||
queue = ActionQueue(rtc_config_disabled)
|
||||
assert queue.queue is None
|
||||
assert queue.original_queue is None
|
||||
assert queue.last_index == 0
|
||||
assert queue.cfg.enabled is False
|
||||
|
||||
|
||||
# ====================== get() Tests ======================
|
||||
|
||||
|
||||
def test_get_returns_none_when_empty(action_queue_rtc_enabled):
|
||||
"""Test get() returns None when queue is empty."""
|
||||
action = action_queue_rtc_enabled.get()
|
||||
assert action is None
|
||||
|
||||
|
||||
def test_get_returns_actions_sequentially(action_queue_rtc_enabled, sample_actions):
|
||||
"""Test get() returns actions in sequence."""
|
||||
# Initialize queue with actions
|
||||
action_queue_rtc_enabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=0)
|
||||
|
||||
# Get first action
|
||||
action1 = action_queue_rtc_enabled.get()
|
||||
assert action1 is not None
|
||||
assert action1.shape == (6,)
|
||||
assert torch.equal(action1, sample_actions["processed"][0])
|
||||
|
||||
# Get second action
|
||||
action2 = action_queue_rtc_enabled.get()
|
||||
assert action2 is not None
|
||||
assert torch.equal(action2, sample_actions["processed"][1])
|
||||
|
||||
|
||||
def test_get_returns_none_after_exhaustion(action_queue_rtc_enabled, sample_actions):
|
||||
"""Test get() returns None after all actions are consumed."""
|
||||
# Use short action sequence
|
||||
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
|
||||
|
||||
# Consume all actions
|
||||
for _ in range(10):
|
||||
action = action_queue_rtc_enabled.get()
|
||||
assert action is not None
|
||||
|
||||
# Next get should return None
|
||||
action = action_queue_rtc_enabled.get()
|
||||
assert action is None
|
||||
|
||||
|
||||
def test_get_increments_last_index(action_queue_rtc_enabled, sample_actions):
|
||||
"""Test get() increments last_index correctly."""
|
||||
action_queue_rtc_enabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=0)
|
||||
|
||||
assert action_queue_rtc_enabled.last_index == 0
|
||||
action_queue_rtc_enabled.get()
|
||||
assert action_queue_rtc_enabled.last_index == 1
|
||||
action_queue_rtc_enabled.get()
|
||||
assert action_queue_rtc_enabled.last_index == 2
|
||||
|
||||
|
||||
# ====================== qsize() Tests ======================
|
||||
|
||||
|
||||
def test_qsize_returns_zero_when_empty(action_queue_rtc_enabled):
|
||||
"""Test qsize() returns 0 when queue is empty."""
|
||||
assert action_queue_rtc_enabled.qsize() == 0
|
||||
|
||||
|
||||
def test_qsize_returns_correct_size(action_queue_rtc_enabled, sample_actions):
|
||||
"""Test qsize() returns correct number of remaining actions."""
|
||||
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
|
||||
assert action_queue_rtc_enabled.qsize() == 10
|
||||
|
||||
action_queue_rtc_enabled.get()
|
||||
assert action_queue_rtc_enabled.qsize() == 9
|
||||
|
||||
action_queue_rtc_enabled.get()
|
||||
assert action_queue_rtc_enabled.qsize() == 8
|
||||
|
||||
|
||||
def test_qsize_after_exhaustion(action_queue_rtc_enabled, sample_actions):
|
||||
"""Test qsize() returns 0 after queue is exhausted."""
|
||||
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
|
||||
|
||||
# Consume all actions
|
||||
for _ in range(10):
|
||||
action_queue_rtc_enabled.get()
|
||||
|
||||
assert action_queue_rtc_enabled.qsize() == 0
|
||||
|
||||
|
||||
# ====================== empty() Tests ======================
|
||||
|
||||
|
||||
def test_empty_returns_true_when_empty(action_queue_rtc_enabled):
|
||||
"""Test empty() returns True when queue is empty."""
|
||||
assert action_queue_rtc_enabled.empty() is True
|
||||
|
||||
|
||||
def test_empty_returns_false_when_not_empty(action_queue_rtc_enabled, sample_actions):
|
||||
"""Test empty() returns False when queue has actions."""
|
||||
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
|
||||
assert action_queue_rtc_enabled.empty() is False
|
||||
|
||||
|
||||
def test_empty_after_partial_consumption(action_queue_rtc_enabled, sample_actions):
|
||||
"""Test empty() returns False after partial consumption."""
|
||||
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
|
||||
|
||||
action_queue_rtc_enabled.get()
|
||||
action_queue_rtc_enabled.get()
|
||||
|
||||
assert action_queue_rtc_enabled.empty() is False
|
||||
|
||||
|
||||
def test_empty_after_full_consumption(action_queue_rtc_enabled, sample_actions):
|
||||
"""Test empty() returns True after all actions consumed."""
|
||||
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
|
||||
|
||||
# Consume all
|
||||
for _ in range(10):
|
||||
action_queue_rtc_enabled.get()
|
||||
|
||||
assert action_queue_rtc_enabled.empty() is True
|
||||
|
||||
|
||||
# ====================== get_action_index() Tests ======================
|
||||
|
||||
|
||||
def test_get_action_index_initial_value(action_queue_rtc_enabled):
|
||||
"""Test get_action_index() returns 0 initially."""
|
||||
assert action_queue_rtc_enabled.get_action_index() == 0
|
||||
|
||||
|
||||
def test_get_action_index_after_consumption(action_queue_rtc_enabled, sample_actions):
|
||||
"""Test get_action_index() tracks consumption correctly."""
|
||||
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
|
||||
|
||||
assert action_queue_rtc_enabled.get_action_index() == 0
|
||||
action_queue_rtc_enabled.get()
|
||||
assert action_queue_rtc_enabled.get_action_index() == 1
|
||||
action_queue_rtc_enabled.get()
|
||||
action_queue_rtc_enabled.get()
|
||||
assert action_queue_rtc_enabled.get_action_index() == 3
|
||||
|
||||
|
||||
# ====================== get_left_over() Tests ======================
|
||||
|
||||
|
||||
def test_get_left_over_returns_none_when_empty(action_queue_rtc_enabled):
|
||||
"""Test get_left_over() returns None when queue is empty."""
|
||||
leftover = action_queue_rtc_enabled.get_left_over()
|
||||
assert leftover is None
|
||||
|
||||
|
||||
def test_get_left_over_returns_all_when_unconsumed(action_queue_rtc_enabled, sample_actions):
|
||||
"""Test get_left_over() returns all original actions when none consumed."""
|
||||
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
|
||||
|
||||
leftover = action_queue_rtc_enabled.get_left_over()
|
||||
assert leftover is not None
|
||||
assert leftover.shape == (10, 6)
|
||||
assert torch.equal(leftover, sample_actions["short"])
|
||||
|
||||
|
||||
def test_get_left_over_returns_remaining_after_consumption(action_queue_rtc_enabled, sample_actions):
|
||||
"""Test get_left_over() returns only remaining original actions."""
|
||||
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
|
||||
|
||||
# Consume 3 actions
|
||||
action_queue_rtc_enabled.get()
|
||||
action_queue_rtc_enabled.get()
|
||||
action_queue_rtc_enabled.get()
|
||||
|
||||
leftover = action_queue_rtc_enabled.get_left_over()
|
||||
assert leftover is not None
|
||||
assert leftover.shape == (7, 6)
|
||||
assert torch.equal(leftover, sample_actions["short"][3:])
|
||||
|
||||
|
||||
def test_get_left_over_returns_empty_after_exhaustion(action_queue_rtc_enabled, sample_actions):
|
||||
"""Test get_left_over() returns empty tensor after all consumed."""
|
||||
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
|
||||
|
||||
# Consume all
|
||||
for _ in range(10):
|
||||
action_queue_rtc_enabled.get()
|
||||
|
||||
leftover = action_queue_rtc_enabled.get_left_over()
|
||||
assert leftover is not None
|
||||
assert leftover.shape == (0, 6)
|
||||
|
||||
|
||||
# ====================== merge() with RTC Enabled Tests ======================
|
||||
|
||||
|
||||
def test_merge_replaces_queue_when_rtc_enabled(action_queue_rtc_enabled, sample_actions):
|
||||
"""Test merge() replaces queue when RTC is enabled."""
|
||||
# Add initial actions
|
||||
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
|
||||
assert action_queue_rtc_enabled.qsize() == 10
|
||||
|
||||
# Consume some actions
|
||||
action_queue_rtc_enabled.get()
|
||||
action_queue_rtc_enabled.get()
|
||||
assert action_queue_rtc_enabled.qsize() == 8
|
||||
|
||||
# Merge new actions - should replace, not append
|
||||
action_queue_rtc_enabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=5)
|
||||
|
||||
# Queue should be replaced with new actions minus delay
|
||||
# Original has 50 actions, delay is 5, so remaining is 45
|
||||
assert action_queue_rtc_enabled.qsize() == 45
|
||||
assert action_queue_rtc_enabled.get_action_index() == 0
|
||||
|
||||
|
||||
def test_merge_respects_real_delay(action_queue_rtc_enabled, sample_actions):
|
||||
"""Test merge() correctly applies real_delay when RTC is enabled."""
|
||||
delay = 10
|
||||
action_queue_rtc_enabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=delay)
|
||||
|
||||
# Queue should have original length minus delay
|
||||
expected_size = len(sample_actions["original"]) - delay
|
||||
assert action_queue_rtc_enabled.qsize() == expected_size
|
||||
|
||||
# First action should be the one at index [delay]
|
||||
first_action = action_queue_rtc_enabled.get()
|
||||
assert torch.equal(first_action, sample_actions["processed"][delay])
|
||||
|
||||
|
||||
def test_merge_resets_last_index_when_rtc_enabled(action_queue_rtc_enabled, sample_actions):
|
||||
"""Test merge() resets last_index to 0 when RTC is enabled."""
|
||||
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
|
||||
action_queue_rtc_enabled.get()
|
||||
action_queue_rtc_enabled.get()
|
||||
assert action_queue_rtc_enabled.last_index == 2
|
||||
|
||||
# Merge new actions
|
||||
action_queue_rtc_enabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=5)
|
||||
|
||||
assert action_queue_rtc_enabled.last_index == 0
|
||||
|
||||
|
||||
def test_merge_with_zero_delay(action_queue_rtc_enabled, sample_actions):
|
||||
"""Test merge() with zero delay keeps all actions."""
|
||||
action_queue_rtc_enabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=0)
|
||||
|
||||
assert action_queue_rtc_enabled.qsize() == len(sample_actions["original"])
|
||||
|
||||
|
||||
def test_merge_with_large_delay(action_queue_rtc_enabled, sample_actions):
|
||||
"""Test merge() with delay larger than action sequence."""
|
||||
# Delay is larger than sequence length
|
||||
delay = 100
|
||||
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=delay)
|
||||
|
||||
# Queue should be empty (delay >= length)
|
||||
assert action_queue_rtc_enabled.qsize() == 0
|
||||
|
||||
|
||||
# ====================== merge() with RTC Disabled Tests ======================
|
||||
|
||||
|
||||
def test_merge_appends_when_rtc_disabled(action_queue_rtc_disabled, sample_actions):
|
||||
"""Test merge() appends actions when RTC is disabled."""
|
||||
# Add initial actions
|
||||
action_queue_rtc_disabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
|
||||
initial_size = action_queue_rtc_disabled.qsize()
|
||||
assert initial_size == 10
|
||||
|
||||
# Merge more actions
|
||||
action_queue_rtc_disabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
|
||||
|
||||
# Should have appended
|
||||
assert action_queue_rtc_disabled.qsize() == initial_size + 10
|
||||
|
||||
|
||||
def test_merge_removes_consumed_actions_when_appending(action_queue_rtc_disabled, sample_actions):
|
||||
"""Test merge() removes consumed actions before appending when RTC is disabled."""
|
||||
# Add initial actions
|
||||
action_queue_rtc_disabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
|
||||
assert action_queue_rtc_disabled.qsize() == 10
|
||||
|
||||
# Consume 3 actions
|
||||
action_queue_rtc_disabled.get()
|
||||
action_queue_rtc_disabled.get()
|
||||
action_queue_rtc_disabled.get()
|
||||
assert action_queue_rtc_disabled.qsize() == 7
|
||||
|
||||
# Merge more actions
|
||||
action_queue_rtc_disabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
|
||||
|
||||
# Should have 7 remaining + 10 new = 17
|
||||
assert action_queue_rtc_disabled.qsize() == 17
|
||||
|
||||
|
||||
def test_merge_resets_last_index_after_append(action_queue_rtc_disabled, sample_actions):
|
||||
"""Test merge() resets last_index after appending when RTC is disabled."""
|
||||
action_queue_rtc_disabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
|
||||
action_queue_rtc_disabled.get()
|
||||
action_queue_rtc_disabled.get()
|
||||
assert action_queue_rtc_disabled.last_index == 2
|
||||
|
||||
# Merge more actions
|
||||
action_queue_rtc_disabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
|
||||
|
||||
# last_index should be reset to 0
|
||||
assert action_queue_rtc_disabled.last_index == 0
|
||||
|
||||
|
||||
def test_merge_ignores_delay_when_rtc_disabled(action_queue_rtc_disabled, sample_actions):
|
||||
"""Test merge() ignores real_delay parameter when RTC is disabled."""
|
||||
action_queue_rtc_disabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=10)
|
||||
|
||||
# All actions should be in queue (delay ignored)
|
||||
assert action_queue_rtc_disabled.qsize() == len(sample_actions["original"])
|
||||
|
||||
|
||||
def test_merge_first_call_with_rtc_disabled(action_queue_rtc_disabled, sample_actions):
|
||||
"""Test merge() on first call with RTC disabled."""
|
||||
action_queue_rtc_disabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=0)
|
||||
|
||||
assert action_queue_rtc_disabled.qsize() == len(sample_actions["original"])
|
||||
assert action_queue_rtc_disabled.last_index == 0
|
||||
|
||||
|
||||
# ====================== merge() with Different Action Shapes Tests ======================
|
||||
|
||||
|
||||
def test_merge_with_different_action_dims():
|
||||
"""Test merge() handles actions with different dimensions."""
|
||||
cfg = RTCConfig(enabled=True, execution_horizon=10)
|
||||
queue = ActionQueue(cfg)
|
||||
|
||||
# Actions with 4 dimensions instead of 6
|
||||
actions_4d = torch.randn(20, 4)
|
||||
queue.merge(actions_4d, actions_4d, real_delay=5)
|
||||
|
||||
action = queue.get()
|
||||
assert action.shape == (4,)
|
||||
|
||||
|
||||
def test_merge_with_different_lengths():
|
||||
"""Test merge() handles action sequences of varying lengths."""
|
||||
cfg = RTCConfig(enabled=False, execution_horizon=10)
|
||||
queue = ActionQueue(cfg)
|
||||
|
||||
# Add sequences of different lengths
|
||||
queue.merge(torch.randn(10, 6), torch.randn(10, 6), real_delay=0)
|
||||
assert queue.qsize() == 10
|
||||
|
||||
queue.merge(torch.randn(25, 6), torch.randn(25, 6), real_delay=0)
|
||||
assert queue.qsize() == 35
|
||||
|
||||
|
||||
# ====================== merge() Delay Validation Tests ======================
|
||||
|
||||
|
||||
def test_merge_validates_delay_consistency(action_queue_rtc_enabled, sample_actions, caplog):
|
||||
"""Test merge() validates that real_delay matches action index difference."""
|
||||
import logging
|
||||
|
||||
caplog.set_level(logging.WARNING)
|
||||
|
||||
# Initialize queue
|
||||
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
|
||||
|
||||
# Consume 5 actions
|
||||
for _ in range(5):
|
||||
action_queue_rtc_enabled.get()
|
||||
|
||||
# Merge with mismatched delay (should log warning)
|
||||
# We consumed 5 actions, so index is 5. If we pass action_index_before_inference=0,
|
||||
# then indexes_diff=5, but if real_delay=3, it will warn
|
||||
action_queue_rtc_enabled.merge(
|
||||
sample_actions["original"],
|
||||
sample_actions["processed"],
|
||||
real_delay=3,
|
||||
action_index_before_inference=0,
|
||||
)
|
||||
|
||||
# Check warning was logged
|
||||
assert "Indexes diff is not equal to real delay" in caplog.text
|
||||
|
||||
|
||||
def test_merge_no_warning_when_delays_match(action_queue_rtc_enabled, sample_actions, caplog):
|
||||
"""Test merge() doesn't warn when delays are consistent."""
|
||||
import logging
|
||||
|
||||
caplog.set_level(logging.WARNING)
|
||||
|
||||
# Initialize queue
|
||||
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
|
||||
|
||||
# Consume 5 actions
|
||||
for _ in range(5):
|
||||
action_queue_rtc_enabled.get()
|
||||
|
||||
# Merge with matching delay
|
||||
action_queue_rtc_enabled.merge(
|
||||
sample_actions["original"],
|
||||
sample_actions["processed"],
|
||||
real_delay=5,
|
||||
action_index_before_inference=0,
|
||||
)
|
||||
|
||||
# Should not have warning
|
||||
assert "Indexes diff is not equal to real delay" not in caplog.text
|
||||
|
||||
|
||||
def test_merge_skips_validation_when_action_index_none(action_queue_rtc_enabled, sample_actions, caplog):
|
||||
"""Test merge() skips delay validation when action_index_before_inference is None."""
|
||||
import logging
|
||||
|
||||
caplog.set_level(logging.WARNING)
|
||||
|
||||
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
|
||||
|
||||
for _ in range(5):
|
||||
action_queue_rtc_enabled.get()
|
||||
|
||||
# Pass None for action_index_before_inference
|
||||
action_queue_rtc_enabled.merge(
|
||||
sample_actions["original"],
|
||||
sample_actions["processed"],
|
||||
real_delay=999, # Doesn't matter
|
||||
action_index_before_inference=None,
|
||||
)
|
||||
|
||||
# Should not warn (validation skipped)
|
||||
assert "Indexes diff is not equal to real delay" not in caplog.text
|
||||
|
||||
|
||||
# ====================== Thread Safety Tests ======================
|
||||
|
||||
|
||||
def test_get_is_thread_safe(action_queue_rtc_enabled, sample_actions):
|
||||
"""Test get() is thread-safe with multiple consumers."""
|
||||
action_queue_rtc_enabled.merge(sample_actions["longer"], sample_actions["longer"], real_delay=0)
|
||||
|
||||
results = []
|
||||
errors = []
|
||||
|
||||
def consumer():
|
||||
try:
|
||||
for _ in range(25):
|
||||
action = action_queue_rtc_enabled.get()
|
||||
if action is not None:
|
||||
results.append(action)
|
||||
time.sleep(0.001)
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
|
||||
threads = [threading.Thread(target=consumer) for _ in range(4)]
|
||||
|
||||
for t in threads:
|
||||
t.start()
|
||||
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
# Should not have errors
|
||||
assert len(errors) == 0
|
||||
|
||||
# Should have consumed all actions (100 total, 4 threads * 25 each)
|
||||
assert len(results) == 100
|
||||
|
||||
# All results should be unique (no duplicate consumption)
|
||||
# We can verify by checking that indices are not duplicated
|
||||
# Since we don't track indices in results, we check total count is correct
|
||||
assert action_queue_rtc_enabled.qsize() == 0
|
||||
|
||||
|
||||
def test_merge_is_thread_safe(action_queue_rtc_disabled, sample_actions):
|
||||
"""Test merge() is thread-safe with multiple producers."""
|
||||
errors = []
|
||||
|
||||
def producer():
|
||||
try:
|
||||
for _ in range(5):
|
||||
action_queue_rtc_disabled.merge(
|
||||
sample_actions["short"], sample_actions["short"], real_delay=0
|
||||
)
|
||||
time.sleep(0.001)
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
|
||||
threads = [threading.Thread(target=producer) for _ in range(3)]
|
||||
|
||||
for t in threads:
|
||||
t.start()
|
||||
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
# Should not have errors
|
||||
assert len(errors) == 0
|
||||
|
||||
# Should have accumulated all actions (3 threads * 5 merges * 10 actions = 150)
|
||||
assert action_queue_rtc_disabled.qsize() == 150
|
||||
|
||||
|
||||
def test_concurrent_get_and_merge(action_queue_rtc_disabled, sample_actions):
|
||||
"""Test concurrent get() and merge() operations."""
|
||||
errors = []
|
||||
consumed_count = [0]
|
||||
|
||||
def consumer():
|
||||
try:
|
||||
for _ in range(50):
|
||||
action = action_queue_rtc_disabled.get()
|
||||
if action is not None:
|
||||
consumed_count[0] += 1
|
||||
time.sleep(0.001)
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
|
||||
def producer():
|
||||
try:
|
||||
for _ in range(10):
|
||||
action_queue_rtc_disabled.merge(
|
||||
sample_actions["short"], sample_actions["short"], real_delay=0
|
||||
)
|
||||
time.sleep(0.005)
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
|
||||
consumer_threads = [threading.Thread(target=consumer) for _ in range(2)]
|
||||
producer_threads = [threading.Thread(target=producer) for _ in range(2)]
|
||||
|
||||
for t in consumer_threads + producer_threads:
|
||||
t.start()
|
||||
|
||||
for t in consumer_threads + producer_threads:
|
||||
t.join()
|
||||
|
||||
# Should not have errors
|
||||
assert len(errors) == 0
|
||||
|
||||
# Should have consumed some or all actions (non-deterministic due to timing)
|
||||
# Total produced: 2 producers * 10 merges * 10 actions = 200
|
||||
# Total consumed attempts: 2 consumers * 50 = 100
|
||||
assert consumed_count[0] <= 200
|
||||
|
||||
|
||||
# ====================== get_left_over() Thread Safety Tests ======================
|
||||
|
||||
|
||||
def test_get_left_over_is_thread_safe(action_queue_rtc_enabled, sample_actions):
|
||||
"""Test get_left_over() is thread-safe with concurrent access."""
|
||||
action_queue_rtc_enabled.merge(sample_actions["longer"], sample_actions["longer"], real_delay=0)
|
||||
|
||||
errors = []
|
||||
leftovers = []
|
||||
|
||||
def reader():
|
||||
try:
|
||||
for _ in range(20):
|
||||
leftover = action_queue_rtc_enabled.get_left_over()
|
||||
if leftover is not None:
|
||||
leftovers.append(leftover.shape[0])
|
||||
time.sleep(0.001)
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
|
||||
threads = [threading.Thread(target=reader) for _ in range(3)]
|
||||
|
||||
# Also consume some actions concurrently
|
||||
def consumer():
|
||||
try:
|
||||
for _ in range(10):
|
||||
action_queue_rtc_enabled.get()
|
||||
time.sleep(0.002)
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
|
||||
consumer_thread = threading.Thread(target=consumer)
|
||||
|
||||
all_threads = threads + [consumer_thread]
|
||||
|
||||
for t in all_threads:
|
||||
t.start()
|
||||
|
||||
for t in all_threads:
|
||||
t.join()
|
||||
|
||||
# Should not have errors
|
||||
assert len(errors) == 0
|
||||
|
||||
# Leftovers should be monotonically decreasing or stable
|
||||
# (as actions are consumed, leftover size decreases)
|
||||
assert len(leftovers) > 0
|
||||
|
||||
|
||||
# ====================== Edge Cases Tests ======================
|
||||
|
||||
|
||||
def test_queue_with_single_action(action_queue_rtc_enabled):
|
||||
"""Test queue behavior with a single action."""
|
||||
single_action_original = torch.randn(1, 6)
|
||||
single_action_processed = torch.randn(1, 6)
|
||||
|
||||
action_queue_rtc_enabled.merge(single_action_original, single_action_processed, real_delay=0)
|
||||
|
||||
assert action_queue_rtc_enabled.qsize() == 1
|
||||
action = action_queue_rtc_enabled.get()
|
||||
assert action is not None
|
||||
assert action.shape == (6,)
|
||||
assert action_queue_rtc_enabled.qsize() == 0
|
||||
|
||||
|
||||
def test_queue_behavior_after_multiple_merge_cycles(action_queue_rtc_enabled, sample_actions):
|
||||
"""Test queue maintains correct state through multiple merge cycles."""
|
||||
for _ in range(5):
|
||||
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
|
||||
|
||||
# Consume half
|
||||
for _ in range(5):
|
||||
action_queue_rtc_enabled.get()
|
||||
|
||||
# Merge again
|
||||
action_queue_rtc_enabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=3)
|
||||
|
||||
assert action_queue_rtc_enabled.qsize() > 0
|
||||
|
||||
|
||||
def test_queue_with_all_zeros_actions(action_queue_rtc_enabled):
|
||||
"""Test queue handles all-zero action tensors."""
|
||||
zeros_actions = torch.zeros(20, 6)
|
||||
action_queue_rtc_enabled.merge(zeros_actions, zeros_actions, real_delay=0)
|
||||
|
||||
action = action_queue_rtc_enabled.get()
|
||||
assert torch.all(action == 0)
|
||||
|
||||
|
||||
def test_queue_clones_input_tensors(action_queue_rtc_enabled, sample_actions):
|
||||
"""Test that merge() clones input tensors, not storing references."""
|
||||
original_copy = sample_actions["original"].clone()
|
||||
processed_copy = sample_actions["processed"].clone()
|
||||
|
||||
action_queue_rtc_enabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=0)
|
||||
|
||||
# Modify original tensors
|
||||
sample_actions["original"].fill_(999.0)
|
||||
sample_actions["processed"].fill_(-999.0)
|
||||
|
||||
# Queue should have cloned values
|
||||
action = action_queue_rtc_enabled.get()
|
||||
assert not torch.equal(action, sample_actions["processed"][0])
|
||||
assert torch.equal(action, processed_copy[0])
|
||||
|
||||
leftover = action_queue_rtc_enabled.get_left_over()
|
||||
assert not torch.equal(leftover, sample_actions["original"][1:])
|
||||
assert torch.equal(leftover, original_copy[1:])
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_queue_handles_gpu_tensors():
|
||||
"""Test queue correctly handles GPU tensors."""
|
||||
cfg = RTCConfig(enabled=True, execution_horizon=10)
|
||||
queue = ActionQueue(cfg)
|
||||
|
||||
actions_gpu = torch.randn(20, 6, device="cuda")
|
||||
queue.merge(actions_gpu, actions_gpu, real_delay=0)
|
||||
|
||||
action = queue.get()
|
||||
assert action.device.type == "cuda"
|
||||
|
||||
leftover = queue.get_left_over()
|
||||
assert leftover.device.type == "cuda"
|
||||
|
||||
|
||||
def test_queue_handles_different_dtypes():
|
||||
"""Test queue handles actions with different dtypes."""
|
||||
cfg = RTCConfig(enabled=True, execution_horizon=10)
|
||||
queue = ActionQueue(cfg)
|
||||
|
||||
# Use float64 instead of default float32
|
||||
actions_f64 = torch.randn(20, 6, dtype=torch.float64)
|
||||
queue.merge(actions_f64, actions_f64, real_delay=0)
|
||||
|
||||
action = queue.get()
|
||||
assert action.dtype == torch.float64
|
||||
|
||||
|
||||
def test_empty_with_none_queue(action_queue_rtc_enabled):
|
||||
"""Test empty() correctly handles None queue."""
|
||||
assert action_queue_rtc_enabled.queue is None
|
||||
assert action_queue_rtc_enabled.empty() is True
|
||||
|
||||
|
||||
def test_qsize_with_none_queue(action_queue_rtc_enabled):
|
||||
"""Test qsize() correctly handles None queue."""
|
||||
assert action_queue_rtc_enabled.queue is None
|
||||
assert action_queue_rtc_enabled.qsize() == 0
|
||||
|
||||
|
||||
# ====================== Integration Tests ======================
|
||||
|
||||
|
||||
def test_typical_rtc_workflow(action_queue_rtc_enabled, sample_actions):
|
||||
"""Test a typical RTC workflow: merge, consume, merge with delay."""
|
||||
# First inference
|
||||
action_queue_rtc_enabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=0)
|
||||
initial_size = action_queue_rtc_enabled.qsize()
|
||||
assert initial_size == 50
|
||||
|
||||
# Consume 10 actions (execution_horizon)
|
||||
for _ in range(10):
|
||||
action = action_queue_rtc_enabled.get()
|
||||
assert action is not None
|
||||
|
||||
assert action_queue_rtc_enabled.qsize() == 40
|
||||
|
||||
# Second inference with delay
|
||||
action_index_before = action_queue_rtc_enabled.get_action_index()
|
||||
|
||||
action_queue_rtc_enabled.merge(
|
||||
sample_actions["original"],
|
||||
sample_actions["processed"],
|
||||
real_delay=5,
|
||||
action_index_before_inference=action_index_before,
|
||||
)
|
||||
|
||||
# Queue should be replaced, minus delay
|
||||
assert action_queue_rtc_enabled.qsize() == 45
|
||||
assert action_queue_rtc_enabled.get_action_index() == 0
|
||||
|
||||
|
||||
def test_typical_non_rtc_workflow(action_queue_rtc_disabled, sample_actions):
|
||||
"""Test a typical non-RTC workflow: merge, consume, merge again."""
|
||||
# First inference
|
||||
action_queue_rtc_disabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=0)
|
||||
assert action_queue_rtc_disabled.qsize() == 50
|
||||
|
||||
# Consume 40 actions
|
||||
for _ in range(40):
|
||||
action = action_queue_rtc_disabled.get()
|
||||
assert action is not None
|
||||
|
||||
assert action_queue_rtc_disabled.qsize() == 10
|
||||
|
||||
# Second inference (should append)
|
||||
action_queue_rtc_disabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=0)
|
||||
|
||||
# Should have 10 remaining + 50 new = 60
|
||||
assert action_queue_rtc_disabled.qsize() == 60
|
||||
@@ -0,0 +1,65 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for RTC configuration module."""
|
||||
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
|
||||
# ====================== Initialization Tests ======================
|
||||
|
||||
|
||||
def test_rtc_config_default_initialization():
|
||||
"""Test RTCConfig initializes with default values."""
|
||||
config = RTCConfig()
|
||||
|
||||
assert config.enabled is False
|
||||
assert config.prefix_attention_schedule == RTCAttentionSchedule.LINEAR
|
||||
assert config.max_guidance_weight == 10.0
|
||||
assert config.execution_horizon == 10
|
||||
assert config.debug is False
|
||||
assert config.debug_maxlen == 100
|
||||
|
||||
|
||||
def test_rtc_config_custom_initialization():
|
||||
"""Test RTCConfig initializes with custom values."""
|
||||
config = RTCConfig(
|
||||
enabled=True,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||
max_guidance_weight=5.0,
|
||||
execution_horizon=20,
|
||||
debug=True,
|
||||
debug_maxlen=200,
|
||||
)
|
||||
|
||||
assert config.enabled is True
|
||||
assert config.prefix_attention_schedule == RTCAttentionSchedule.EXP
|
||||
assert config.max_guidance_weight == 5.0
|
||||
assert config.execution_horizon == 20
|
||||
assert config.debug is True
|
||||
assert config.debug_maxlen == 200
|
||||
|
||||
|
||||
def test_rtc_config_partial_initialization():
|
||||
"""Test RTCConfig with partial custom values."""
|
||||
config = RTCConfig(enabled=True, max_guidance_weight=15.0)
|
||||
|
||||
assert config.enabled is True
|
||||
assert config.max_guidance_weight == 15.0
|
||||
# Other values should be defaults
|
||||
assert config.prefix_attention_schedule == RTCAttentionSchedule.LINEAR
|
||||
assert config.execution_horizon == 10
|
||||
assert config.debug is False
|
||||
@@ -0,0 +1,488 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for RTC debug tracker module."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.policies.rtc.debug_tracker import DebugStep, Tracker
|
||||
|
||||
# ====================== Fixtures ======================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_tensors():
|
||||
"""Create sample tensors for testing."""
|
||||
return {
|
||||
"x_t": torch.randn(1, 50, 6),
|
||||
"v_t": torch.randn(1, 50, 6),
|
||||
"x1_t": torch.randn(1, 50, 6),
|
||||
"correction": torch.randn(1, 50, 6),
|
||||
"err": torch.randn(1, 50, 6),
|
||||
"weights": torch.randn(1, 50, 1),
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def enabled_tracker():
|
||||
"""Create an enabled tracker with default settings."""
|
||||
return Tracker(enabled=True, maxlen=100)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def disabled_tracker():
|
||||
"""Create a disabled tracker."""
|
||||
return Tracker(enabled=False)
|
||||
|
||||
|
||||
# ====================== DebugStep Tests ======================
|
||||
|
||||
|
||||
def test_debug_step_initialization():
|
||||
"""Test that DebugStep can be initialized with default values."""
|
||||
step = DebugStep()
|
||||
assert step.step_idx == 0
|
||||
assert step.x_t is None
|
||||
assert step.v_t is None
|
||||
assert step.x1_t is None
|
||||
assert step.correction is None
|
||||
assert step.err is None
|
||||
assert step.weights is None
|
||||
assert step.guidance_weight is None
|
||||
assert step.time is None
|
||||
assert step.inference_delay is None
|
||||
assert step.execution_horizon is None
|
||||
assert step.metadata == {}
|
||||
|
||||
|
||||
def test_debug_step_with_values(sample_tensors):
|
||||
"""Test DebugStep initialization with actual values."""
|
||||
step = DebugStep(
|
||||
step_idx=5,
|
||||
x_t=sample_tensors["x_t"],
|
||||
v_t=sample_tensors["v_t"],
|
||||
x1_t=sample_tensors["x1_t"],
|
||||
correction=sample_tensors["correction"],
|
||||
err=sample_tensors["err"],
|
||||
weights=sample_tensors["weights"],
|
||||
guidance_weight=2.5,
|
||||
time=0.8,
|
||||
inference_delay=4,
|
||||
execution_horizon=8,
|
||||
metadata={"custom_key": "custom_value"},
|
||||
)
|
||||
|
||||
assert step.step_idx == 5
|
||||
assert torch.equal(step.x_t, sample_tensors["x_t"])
|
||||
assert torch.equal(step.v_t, sample_tensors["v_t"])
|
||||
assert torch.equal(step.x1_t, sample_tensors["x1_t"])
|
||||
assert torch.equal(step.correction, sample_tensors["correction"])
|
||||
assert torch.equal(step.err, sample_tensors["err"])
|
||||
assert torch.equal(step.weights, sample_tensors["weights"])
|
||||
assert step.guidance_weight == 2.5
|
||||
assert step.time == 0.8
|
||||
assert step.inference_delay == 4
|
||||
assert step.execution_horizon == 8
|
||||
assert step.metadata == {"custom_key": "custom_value"}
|
||||
|
||||
|
||||
def test_debug_step_to_dict_without_tensors(sample_tensors):
|
||||
"""Test converting DebugStep to dictionary without tensor values."""
|
||||
step = DebugStep(
|
||||
step_idx=3,
|
||||
x_t=sample_tensors["x_t"],
|
||||
v_t=sample_tensors["v_t"],
|
||||
guidance_weight=torch.tensor(3.0),
|
||||
time=torch.tensor(0.5),
|
||||
inference_delay=2,
|
||||
execution_horizon=10,
|
||||
)
|
||||
|
||||
result = step.to_dict(include_tensors=False)
|
||||
|
||||
assert result["step_idx"] == 3
|
||||
assert result["guidance_weight"] == 3.0
|
||||
assert result["time"] == 0.5
|
||||
assert result["inference_delay"] == 2
|
||||
assert result["execution_horizon"] == 10
|
||||
|
||||
# Check tensor statistics are included
|
||||
assert "x_t_stats" in result
|
||||
assert "v_t_stats" in result
|
||||
assert "x1_t_stats" not in result # x1_t was None
|
||||
|
||||
# Verify statistics structure
|
||||
assert "shape" in result["x_t_stats"]
|
||||
assert "mean" in result["x_t_stats"]
|
||||
assert "std" in result["x_t_stats"]
|
||||
assert "min" in result["x_t_stats"]
|
||||
assert "max" in result["x_t_stats"]
|
||||
|
||||
# Verify shape matches original tensor
|
||||
assert result["x_t_stats"]["shape"] == tuple(sample_tensors["x_t"].shape)
|
||||
|
||||
|
||||
def test_debug_step_to_dict_with_tensors(sample_tensors):
|
||||
"""Test converting DebugStep to dictionary with tensor values."""
|
||||
step = DebugStep(
|
||||
step_idx=1,
|
||||
x_t=sample_tensors["x_t"],
|
||||
v_t=sample_tensors["v_t"],
|
||||
guidance_weight=1.5,
|
||||
time=0.9,
|
||||
)
|
||||
|
||||
result = step.to_dict(include_tensors=True)
|
||||
|
||||
assert result["step_idx"] == 1
|
||||
assert result["guidance_weight"] == 1.5
|
||||
assert result["time"] == 0.9
|
||||
|
||||
# Check tensors are included (as CPU tensors)
|
||||
assert "x_t" in result
|
||||
assert "v_t" in result
|
||||
assert isinstance(result["x_t"], torch.Tensor)
|
||||
assert isinstance(result["v_t"], torch.Tensor)
|
||||
assert result["x_t"].device.type == "cpu"
|
||||
assert result["v_t"].device.type == "cpu"
|
||||
|
||||
|
||||
def test_debug_step_to_dict_with_none_guidance_weight():
|
||||
"""Test to_dict handles None guidance_weight correctly."""
|
||||
step = DebugStep(step_idx=0, time=1.0, guidance_weight=None)
|
||||
result = step.to_dict(include_tensors=False)
|
||||
assert result["guidance_weight"] is None
|
||||
|
||||
|
||||
def test_tracker_initialization_enabled():
|
||||
"""Test tracker initialization when enabled."""
|
||||
tracker = Tracker(enabled=True, maxlen=50)
|
||||
assert tracker.enabled is True
|
||||
assert tracker._steps == {}
|
||||
assert tracker._maxlen == 50
|
||||
assert tracker._step_counter == 0
|
||||
assert len(tracker) == 0
|
||||
|
||||
|
||||
def test_tracker_reset_when_enabled(enabled_tracker, sample_tensors):
|
||||
"""Test reset clears all steps when tracker is enabled."""
|
||||
# Add some steps
|
||||
enabled_tracker.track(time=1.0, x_t=sample_tensors["x_t"])
|
||||
enabled_tracker.track(time=0.9, x_t=sample_tensors["x_t"])
|
||||
assert len(enabled_tracker) == 2
|
||||
|
||||
# Reset
|
||||
enabled_tracker.reset()
|
||||
assert len(enabled_tracker) == 0
|
||||
assert enabled_tracker._step_counter == 0
|
||||
assert enabled_tracker._steps == {}
|
||||
|
||||
|
||||
def test_tracker_reset_when_disabled(disabled_tracker):
|
||||
"""Test reset on disabled tracker doesn't cause errors."""
|
||||
disabled_tracker.reset()
|
||||
assert len(disabled_tracker) == 0
|
||||
|
||||
|
||||
# ====================== Tracker.track() Tests ======================
|
||||
|
||||
|
||||
def test_track_creates_new_step(enabled_tracker, sample_tensors):
|
||||
"""Test that track creates a new step when time doesn't exist."""
|
||||
enabled_tracker.track(
|
||||
time=1.0,
|
||||
x_t=sample_tensors["x_t"],
|
||||
v_t=sample_tensors["v_t"],
|
||||
guidance_weight=5.0,
|
||||
inference_delay=4,
|
||||
execution_horizon=8,
|
||||
)
|
||||
|
||||
assert len(enabled_tracker) == 1
|
||||
steps = enabled_tracker.get_all_steps()
|
||||
assert len(steps) == 1
|
||||
assert steps[0].step_idx == 0
|
||||
assert steps[0].time == 1.0
|
||||
assert torch.equal(steps[0].x_t, sample_tensors["x_t"])
|
||||
assert torch.equal(steps[0].v_t, sample_tensors["v_t"])
|
||||
assert steps[0].guidance_weight == 5.0
|
||||
assert steps[0].inference_delay == 4
|
||||
assert steps[0].execution_horizon == 8
|
||||
|
||||
|
||||
def test_track_updates_existing_step(enabled_tracker, sample_tensors):
|
||||
"""Test that track updates an existing step at the same time."""
|
||||
# Create initial step
|
||||
enabled_tracker.track(time=0.9, x_t=sample_tensors["x_t"])
|
||||
assert len(enabled_tracker) == 1
|
||||
steps = enabled_tracker.get_all_steps()
|
||||
assert steps[0].v_t is None
|
||||
|
||||
# Update the same timestep with v_t
|
||||
enabled_tracker.track(time=0.9, v_t=sample_tensors["v_t"])
|
||||
assert len(enabled_tracker) == 1 # Still only one step
|
||||
steps = enabled_tracker.get_all_steps()
|
||||
assert torch.equal(steps[0].x_t, sample_tensors["x_t"]) # Original x_t preserved
|
||||
assert torch.equal(steps[0].v_t, sample_tensors["v_t"]) # New v_t added
|
||||
|
||||
|
||||
def test_track_with_tensor_time(enabled_tracker, sample_tensors):
|
||||
"""Test track handles tensor time values correctly."""
|
||||
time_tensor = torch.tensor(0.8)
|
||||
enabled_tracker.track(time=time_tensor, x_t=sample_tensors["x_t"])
|
||||
|
||||
steps = enabled_tracker.get_all_steps()
|
||||
assert len(steps) == 1
|
||||
assert abs(steps[0].time - 0.8) < 1e-6 # Use approximate comparison for floating point
|
||||
|
||||
|
||||
def test_track_time_rounding(enabled_tracker, sample_tensors):
|
||||
"""Test that track rounds time to avoid floating point precision issues."""
|
||||
# These times should be treated as the same after rounding to 6 decimals
|
||||
enabled_tracker.track(time=0.9000001, x_t=sample_tensors["x_t"])
|
||||
enabled_tracker.track(time=0.9000002, v_t=sample_tensors["v_t"])
|
||||
|
||||
# Should still be one step (times rounded to same value)
|
||||
assert len(enabled_tracker) == 1
|
||||
steps = enabled_tracker.get_all_steps()
|
||||
assert torch.equal(steps[0].x_t, sample_tensors["x_t"])
|
||||
assert torch.equal(steps[0].v_t, sample_tensors["v_t"])
|
||||
|
||||
|
||||
def test_track_does_nothing_when_disabled(disabled_tracker, sample_tensors):
|
||||
"""Test that track does nothing when tracker is disabled."""
|
||||
disabled_tracker.track(time=1.0, x_t=sample_tensors["x_t"])
|
||||
assert len(disabled_tracker) == 0
|
||||
|
||||
|
||||
def test_track_with_metadata(enabled_tracker, sample_tensors):
|
||||
"""Test track stores custom metadata."""
|
||||
enabled_tracker.track(time=0.7, x_t=sample_tensors["x_t"], custom_field="custom_value", count=42)
|
||||
|
||||
steps = enabled_tracker.get_all_steps()
|
||||
assert steps[0].metadata["custom_field"] == "custom_value"
|
||||
assert steps[0].metadata["count"] == 42
|
||||
|
||||
|
||||
def test_track_updates_metadata(enabled_tracker):
|
||||
"""Test that track updates metadata for existing steps."""
|
||||
enabled_tracker.track(time=0.6, meta1="value1")
|
||||
enabled_tracker.track(time=0.6, meta2="value2")
|
||||
|
||||
steps = enabled_tracker.get_all_steps()
|
||||
assert steps[0].metadata["meta1"] == "value1"
|
||||
assert steps[0].metadata["meta2"] == "value2"
|
||||
|
||||
|
||||
def test_track_clones_tensors(enabled_tracker, sample_tensors):
|
||||
"""Test that track clones tensors instead of storing references."""
|
||||
x_t_original = sample_tensors["x_t"].clone()
|
||||
enabled_tracker.track(time=0.5, x_t=sample_tensors["x_t"])
|
||||
|
||||
# Modify original tensor
|
||||
sample_tensors["x_t"].fill_(999.0)
|
||||
|
||||
# Tracked tensor should not be affected
|
||||
steps = enabled_tracker.get_all_steps()
|
||||
assert not torch.equal(steps[0].x_t, sample_tensors["x_t"])
|
||||
assert torch.equal(steps[0].x_t, x_t_original)
|
||||
|
||||
|
||||
def test_track_with_none_values(enabled_tracker):
|
||||
"""Test track handles None values correctly."""
|
||||
enabled_tracker.track(
|
||||
time=0.4,
|
||||
x_t=None,
|
||||
v_t=None,
|
||||
guidance_weight=None,
|
||||
inference_delay=None,
|
||||
)
|
||||
|
||||
steps = enabled_tracker.get_all_steps()
|
||||
assert len(steps) == 1
|
||||
assert steps[0].x_t is None
|
||||
assert steps[0].v_t is None
|
||||
assert steps[0].guidance_weight is None
|
||||
assert steps[0].inference_delay is None
|
||||
|
||||
|
||||
def test_track_updates_only_non_none_fields(enabled_tracker, sample_tensors):
|
||||
"""Test that update preserves existing values when None is passed."""
|
||||
# Create step with x_t
|
||||
enabled_tracker.track(time=0.3, x_t=sample_tensors["x_t"], guidance_weight=2.0)
|
||||
|
||||
# Update with v_t only (pass None for other fields)
|
||||
enabled_tracker.track(time=0.3, v_t=sample_tensors["v_t"], x_t=None, guidance_weight=None)
|
||||
|
||||
# Original values should be preserved
|
||||
steps = enabled_tracker.get_all_steps()
|
||||
assert torch.equal(steps[0].x_t, sample_tensors["x_t"]) # Still has x_t
|
||||
assert torch.equal(steps[0].v_t, sample_tensors["v_t"]) # Now has v_t
|
||||
assert steps[0].guidance_weight == 2.0 # Still has guidance_weight
|
||||
|
||||
|
||||
# ====================== Tracker.maxlen Tests ======================
|
||||
|
||||
|
||||
def test_tracker_enforces_maxlen():
|
||||
"""Test that tracker enforces maxlen limit."""
|
||||
tracker = Tracker(enabled=True, maxlen=3)
|
||||
|
||||
# Add 5 steps
|
||||
for i in range(5):
|
||||
time = 1.0 - i * 0.1 # 1.0, 0.9, 0.8, 0.7, 0.6
|
||||
tracker.track(time=time, x_t=torch.randn(1, 10, 6))
|
||||
|
||||
# Should only keep the last 3
|
||||
assert len(tracker) == 3
|
||||
|
||||
# Verify oldest steps were removed (should have 0.6, 0.7, 0.8)
|
||||
steps = tracker.get_all_steps()
|
||||
times = sorted([step.time for step in steps])
|
||||
assert times == [0.6, 0.7, 0.8]
|
||||
|
||||
|
||||
def test_tracker_step_idx_increments_despite_maxlen():
|
||||
"""Test that step_idx continues incrementing even when maxlen is enforced."""
|
||||
tracker = Tracker(enabled=True, maxlen=2)
|
||||
|
||||
# Add 4 steps
|
||||
for i in range(4):
|
||||
time = 1.0 - i * 0.1
|
||||
tracker.track(time=time, x_t=torch.randn(1, 10, 6))
|
||||
|
||||
# Should have 2 steps with step_idx 2 and 3 (oldest removed)
|
||||
steps = sorted(tracker.get_all_steps(), key=lambda s: s.step_idx)
|
||||
assert len(steps) == 2
|
||||
assert steps[0].step_idx == 2
|
||||
assert steps[1].step_idx == 3
|
||||
|
||||
|
||||
def test_tracker_without_maxlen_keeps_all():
|
||||
"""Test that tracker without maxlen keeps all steps."""
|
||||
tracker = Tracker(enabled=True, maxlen=None)
|
||||
|
||||
# Add 100 steps
|
||||
for i in range(100):
|
||||
time = 1.0 - i * 0.01
|
||||
tracker.track(time=time, x_t=torch.randn(1, 10, 6))
|
||||
|
||||
assert len(tracker) == 100
|
||||
|
||||
|
||||
def test_get_all_steps_returns_empty_when_disabled(disabled_tracker):
|
||||
"""Test get_all_steps returns empty list when disabled."""
|
||||
steps = disabled_tracker.get_all_steps()
|
||||
assert steps == []
|
||||
assert isinstance(steps, list)
|
||||
|
||||
|
||||
def test_get_all_steps_returns_empty_when_no_steps(enabled_tracker):
|
||||
"""Test get_all_steps returns empty list when no steps tracked."""
|
||||
steps = enabled_tracker.get_all_steps()
|
||||
assert steps == []
|
||||
|
||||
|
||||
def test_get_all_steps_returns_all_tracked_steps(enabled_tracker, sample_tensors):
|
||||
"""Test get_all_steps returns all tracked steps."""
|
||||
# Track 5 steps
|
||||
for i in range(5):
|
||||
time = 1.0 - i * 0.1
|
||||
enabled_tracker.track(time=time, x_t=sample_tensors["x_t"])
|
||||
|
||||
steps = enabled_tracker.get_all_steps()
|
||||
assert len(steps) == 5
|
||||
|
||||
# Verify all are DebugStep instances
|
||||
for step in steps:
|
||||
assert isinstance(step, DebugStep)
|
||||
|
||||
|
||||
def test_get_all_steps_preserves_insertion_order(enabled_tracker):
|
||||
"""Test that get_all_steps preserves insertion order (Python 3.7+)."""
|
||||
times = [0.9, 0.8, 0.7, 0.6, 0.5]
|
||||
for time in times:
|
||||
enabled_tracker.track(time=time, x_t=torch.randn(1, 10, 6))
|
||||
|
||||
steps = enabled_tracker.get_all_steps()
|
||||
retrieved_times = [step.time for step in steps]
|
||||
|
||||
# Should be in insertion order
|
||||
assert retrieved_times == times
|
||||
|
||||
|
||||
# ====================== Tracker.__len__() Tests ======================
|
||||
|
||||
|
||||
def test_len_returns_zero_when_disabled(disabled_tracker):
|
||||
"""Test __len__ returns 0 when tracker is disabled."""
|
||||
assert len(disabled_tracker) == 0
|
||||
|
||||
|
||||
def test_len_returns_zero_when_empty(enabled_tracker):
|
||||
"""Test __len__ returns 0 when no steps are tracked."""
|
||||
assert len(enabled_tracker) == 0
|
||||
|
||||
|
||||
def test_len_returns_correct_count(enabled_tracker, sample_tensors):
|
||||
"""Test __len__ returns correct number of tracked steps."""
|
||||
assert len(enabled_tracker) == 0
|
||||
|
||||
enabled_tracker.track(time=1.0, x_t=sample_tensors["x_t"])
|
||||
assert len(enabled_tracker) == 1
|
||||
|
||||
enabled_tracker.track(time=0.9, x_t=sample_tensors["x_t"])
|
||||
assert len(enabled_tracker) == 2
|
||||
|
||||
enabled_tracker.track(time=0.8, x_t=sample_tensors["x_t"])
|
||||
assert len(enabled_tracker) == 3
|
||||
|
||||
|
||||
def test_len_after_reset(enabled_tracker, sample_tensors):
|
||||
"""Test __len__ returns 0 after reset."""
|
||||
enabled_tracker.track(time=1.0, x_t=sample_tensors["x_t"])
|
||||
enabled_tracker.track(time=0.9, x_t=sample_tensors["x_t"])
|
||||
assert len(enabled_tracker) == 2
|
||||
|
||||
enabled_tracker.reset()
|
||||
assert len(enabled_tracker) == 0
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_tracker_handles_gpu_tensors():
|
||||
"""Test tracker correctly handles GPU tensors."""
|
||||
tracker = Tracker(enabled=True, maxlen=10)
|
||||
x_t_gpu = torch.randn(1, 50, 6, device="cuda")
|
||||
|
||||
tracker.track(time=1.0, x_t=x_t_gpu)
|
||||
|
||||
steps = tracker.get_all_steps()
|
||||
# Tracker should clone and detach tensors
|
||||
assert steps[0].x_t.device.type == "cuda"
|
||||
|
||||
|
||||
def test_tracker_with_varying_tensor_shapes(enabled_tracker):
|
||||
"""Test tracker handles varying tensor shapes across steps."""
|
||||
enabled_tracker.track(time=1.0, x_t=torch.randn(1, 50, 6))
|
||||
enabled_tracker.track(time=0.9, x_t=torch.randn(1, 25, 6))
|
||||
enabled_tracker.track(time=0.8, x_t=torch.randn(2, 50, 8))
|
||||
|
||||
steps = enabled_tracker.get_all_steps()
|
||||
assert len(steps) == 3
|
||||
assert steps[0].x_t.shape == (1, 50, 6)
|
||||
assert steps[1].x_t.shape == (1, 25, 6)
|
||||
assert steps[2].x_t.shape == (2, 50, 8)
|
||||
@@ -0,0 +1,322 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for RTC LatencyTracker module."""
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.policies.rtc.latency_tracker import LatencyTracker
|
||||
|
||||
# ====================== Fixtures ======================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tracker():
|
||||
"""Create a LatencyTracker with default maxlen."""
|
||||
return LatencyTracker(maxlen=100)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def small_tracker():
|
||||
"""Create a LatencyTracker with small maxlen for overflow testing."""
|
||||
return LatencyTracker(maxlen=5)
|
||||
|
||||
|
||||
# ====================== Initialization Tests ======================
|
||||
|
||||
|
||||
def test_latency_tracker_initialization():
|
||||
"""Test LatencyTracker initializes correctly."""
|
||||
tracker = LatencyTracker(maxlen=50)
|
||||
assert len(tracker) == 0
|
||||
assert tracker.max_latency == 0.0
|
||||
assert tracker.max() == 0.0
|
||||
|
||||
|
||||
def test_latency_tracker_default_maxlen():
|
||||
"""Test LatencyTracker uses default maxlen."""
|
||||
tracker = LatencyTracker()
|
||||
# Should accept default maxlen=100
|
||||
assert len(tracker) == 0
|
||||
|
||||
|
||||
# ====================== add() Tests ======================
|
||||
|
||||
|
||||
def test_add_single_latency(tracker):
|
||||
"""Test adding a single latency value."""
|
||||
tracker.add(0.5)
|
||||
assert len(tracker) == 1
|
||||
assert tracker.max() == 0.5
|
||||
|
||||
|
||||
def test_add_multiple_latencies(tracker):
|
||||
"""Test adding multiple latency values."""
|
||||
latencies = [0.1, 0.5, 0.3, 0.8, 0.2]
|
||||
for lat in latencies:
|
||||
tracker.add(lat)
|
||||
|
||||
assert len(tracker) == 5
|
||||
assert tracker.max() == 0.8
|
||||
|
||||
|
||||
def test_add_negative_latency_ignored(tracker):
|
||||
"""Test that negative latencies are ignored."""
|
||||
tracker.add(0.5)
|
||||
tracker.add(-0.1)
|
||||
tracker.add(0.3)
|
||||
|
||||
# Should only have 2 valid latencies
|
||||
assert len(tracker) == 2
|
||||
assert tracker.max() == 0.5
|
||||
|
||||
|
||||
def test_add_zero_latency(tracker):
|
||||
"""Test adding zero latency."""
|
||||
tracker.add(0.0)
|
||||
assert len(tracker) == 1
|
||||
assert tracker.max() == 0.0
|
||||
|
||||
|
||||
def test_add_converts_to_float(tracker):
|
||||
"""Test add() converts input to float."""
|
||||
tracker.add(5) # Integer
|
||||
tracker.add("3.5") # String
|
||||
|
||||
assert len(tracker) == 2
|
||||
assert tracker.max() == 5.0
|
||||
|
||||
|
||||
def test_add_updates_max_latency(tracker):
|
||||
"""Test that max_latency is updated correctly."""
|
||||
tracker.add(0.5)
|
||||
assert tracker.max_latency == 0.5
|
||||
|
||||
tracker.add(0.3)
|
||||
assert tracker.max_latency == 0.5 # Should not decrease
|
||||
|
||||
tracker.add(0.9)
|
||||
assert tracker.max_latency == 0.9 # Should increase
|
||||
|
||||
|
||||
# ====================== reset() Tests ======================
|
||||
|
||||
|
||||
def test_reset_clears_values(tracker):
|
||||
"""Test reset() clears all values."""
|
||||
tracker.add(0.5)
|
||||
tracker.add(0.8)
|
||||
tracker.add(0.3)
|
||||
assert len(tracker) == 3
|
||||
|
||||
tracker.reset()
|
||||
assert len(tracker) == 0
|
||||
assert tracker.max_latency == 0.0
|
||||
|
||||
|
||||
def test_reset_clears_max_latency(tracker):
|
||||
"""Test reset() resets max_latency."""
|
||||
tracker.add(1.5)
|
||||
assert tracker.max_latency == 1.5
|
||||
|
||||
tracker.reset()
|
||||
assert tracker.max_latency == 0.0
|
||||
|
||||
|
||||
def test_reset_allows_new_values(tracker):
|
||||
"""Test that tracker works correctly after reset."""
|
||||
tracker.add(0.5)
|
||||
tracker.reset()
|
||||
|
||||
tracker.add(0.3)
|
||||
assert len(tracker) == 1
|
||||
assert tracker.max() == 0.3
|
||||
|
||||
|
||||
# ====================== max() Tests ======================
|
||||
|
||||
|
||||
def test_max_returns_zero_when_empty(tracker):
|
||||
"""Test max() returns 0.0 when tracker is empty."""
|
||||
assert tracker.max() == 0.0
|
||||
|
||||
|
||||
def test_max_returns_maximum_value(tracker):
|
||||
"""Test max() returns the maximum latency."""
|
||||
latencies = [0.2, 0.8, 0.3, 0.5, 0.1]
|
||||
for lat in latencies:
|
||||
tracker.add(lat)
|
||||
|
||||
assert tracker.max() == 0.8
|
||||
|
||||
|
||||
def test_max_persists_after_sliding_window(small_tracker):
|
||||
"""Test max() persists even after values slide out of window."""
|
||||
# Add values that will exceed maxlen=5
|
||||
small_tracker.add(0.1)
|
||||
small_tracker.add(0.9) # This is max
|
||||
small_tracker.add(0.2)
|
||||
small_tracker.add(0.3)
|
||||
small_tracker.add(0.4)
|
||||
small_tracker.add(0.5) # This pushes out 0.1
|
||||
|
||||
# Max should still be 0.9 even though only last 5 values kept
|
||||
assert small_tracker.max() == 0.9
|
||||
|
||||
|
||||
def test_max_after_reset(tracker):
|
||||
"""Test max() returns 0.0 after reset."""
|
||||
tracker.add(1.5)
|
||||
tracker.reset()
|
||||
assert tracker.max() == 0.0
|
||||
|
||||
|
||||
# ====================== p95() Tests ======================
|
||||
|
||||
|
||||
def test_p95_returns_zero_when_empty(tracker):
|
||||
"""Test p95() returns 0.0 when tracker is empty."""
|
||||
assert tracker.p95() == 0.0
|
||||
|
||||
|
||||
def test_p95_returns_95th_percentile(tracker):
|
||||
"""Test p95() returns the 95th percentile."""
|
||||
# Add 100 values
|
||||
for i in range(100):
|
||||
tracker.add(i / 100.0)
|
||||
|
||||
p95 = tracker.p95()
|
||||
assert 0.93 <= p95 <= 0.96
|
||||
|
||||
|
||||
def test_p95_equals_percentile_95(tracker):
|
||||
"""Test p95() equals percentile(0.95)."""
|
||||
for i in range(50):
|
||||
tracker.add(i / 50.0)
|
||||
|
||||
assert tracker.p95() == tracker.percentile(0.95)
|
||||
|
||||
|
||||
# ====================== Edge Cases Tests ======================
|
||||
|
||||
|
||||
def test_single_value(tracker):
|
||||
"""Test tracker behavior with single value."""
|
||||
tracker.add(0.75)
|
||||
|
||||
assert len(tracker) == 1
|
||||
assert tracker.max() == 0.75
|
||||
assert tracker.percentile(0.0) == 0.75
|
||||
assert tracker.percentile(0.5) == 0.75
|
||||
assert tracker.percentile(1.0) == 0.75
|
||||
|
||||
|
||||
def test_all_same_values(tracker):
|
||||
"""Test tracker with all identical values."""
|
||||
for _ in range(10):
|
||||
tracker.add(0.5)
|
||||
|
||||
assert len(tracker) == 10
|
||||
assert tracker.max() == 0.5
|
||||
assert tracker.percentile(0.0) == 0.5
|
||||
assert tracker.percentile(0.5) == 0.5
|
||||
assert tracker.percentile(1.0) == 0.5
|
||||
|
||||
|
||||
def test_very_small_values(tracker):
|
||||
"""Test tracker with very small float values."""
|
||||
tracker.add(1e-10)
|
||||
tracker.add(2e-10)
|
||||
tracker.add(3e-10)
|
||||
|
||||
assert len(tracker) == 3
|
||||
assert tracker.max() == pytest.approx(3e-10)
|
||||
|
||||
|
||||
def test_very_large_values(tracker):
|
||||
"""Test tracker with very large float values."""
|
||||
tracker.add(1e10)
|
||||
tracker.add(2e10)
|
||||
tracker.add(3e10)
|
||||
|
||||
assert len(tracker) == 3
|
||||
assert tracker.max() == pytest.approx(3e10)
|
||||
|
||||
|
||||
# ====================== Integration Tests ======================
|
||||
|
||||
|
||||
def test_typical_usage_pattern(tracker):
|
||||
"""Test a typical usage pattern of the tracker."""
|
||||
# Simulate adding latencies over time
|
||||
latencies = [0.05, 0.08, 0.12, 0.07, 0.15, 0.09, 0.11, 0.06, 0.14, 0.10]
|
||||
|
||||
for lat in latencies:
|
||||
tracker.add(lat)
|
||||
|
||||
# Check statistics
|
||||
assert len(tracker) == 10
|
||||
assert tracker.max() == 0.15
|
||||
|
||||
# p95 should be close to max since we have only 10 values
|
||||
p95 = tracker.p95()
|
||||
assert p95 >= tracker.percentile(0.5) # p95 should be >= median
|
||||
assert p95 <= tracker.max() # p95 should be <= max
|
||||
|
||||
|
||||
def test_reset_and_reuse(tracker):
|
||||
"""Test resetting and reusing tracker."""
|
||||
# First batch
|
||||
tracker.add(1.0)
|
||||
tracker.add(2.0)
|
||||
assert tracker.max() == 2.0
|
||||
|
||||
# Reset
|
||||
tracker.reset()
|
||||
|
||||
# Second batch
|
||||
tracker.add(0.5)
|
||||
tracker.add(0.8)
|
||||
assert len(tracker) == 2
|
||||
assert tracker.max() == 0.8
|
||||
assert tracker.percentile(0.5) <= 0.8
|
||||
|
||||
|
||||
# ====================== Type Conversion Tests ======================
|
||||
|
||||
|
||||
def test_add_with_integer(tracker):
|
||||
"""Test adding integer values."""
|
||||
tracker.add(5)
|
||||
assert len(tracker) == 1
|
||||
assert tracker.max() == 5.0
|
||||
|
||||
|
||||
def test_add_with_string_number(tracker):
|
||||
"""Test adding string representation of number."""
|
||||
tracker.add("3.14")
|
||||
assert len(tracker) == 1
|
||||
assert tracker.max() == pytest.approx(3.14)
|
||||
|
||||
|
||||
def test_percentile_converts_q_to_float(tracker):
|
||||
"""Test percentile converts q parameter to float."""
|
||||
tracker.add(0.5)
|
||||
tracker.add(0.8)
|
||||
|
||||
# Pass integer q
|
||||
result = tracker.percentile(1)
|
||||
assert result == 0.8
|
||||
@@ -0,0 +1,773 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for RTC modeling module (RTCProcessor)."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
|
||||
# ====================== Fixtures ======================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def rtc_config_debug_enabled():
|
||||
"""Create RTC config with debug enabled."""
|
||||
return RTCConfig(
|
||||
enabled=True,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.LINEAR,
|
||||
max_guidance_weight=10.0,
|
||||
execution_horizon=10,
|
||||
debug=True,
|
||||
debug_maxlen=100,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def rtc_config_debug_disabled():
|
||||
"""Create RTC config with debug disabled."""
|
||||
return RTCConfig(
|
||||
enabled=True,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.LINEAR,
|
||||
max_guidance_weight=10.0,
|
||||
execution_horizon=10,
|
||||
debug=False,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def rtc_processor_debug_enabled(rtc_config_debug_enabled):
|
||||
"""Create RTCProcessor with debug enabled."""
|
||||
return RTCProcessor(rtc_config_debug_enabled)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def rtc_processor_debug_disabled(rtc_config_debug_disabled):
|
||||
"""Create RTCProcessor with debug disabled."""
|
||||
return RTCProcessor(rtc_config_debug_disabled)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_x_t():
|
||||
"""Create sample x_t tensor (batch, time, action_dim)."""
|
||||
return torch.randn(1, 50, 6)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_prev_chunk():
|
||||
"""Create sample previous chunk tensor."""
|
||||
return torch.randn(1, 50, 6)
|
||||
|
||||
|
||||
# ====================== Initialization Tests ======================
|
||||
|
||||
|
||||
def test_rtc_processor_initialization_with_debug(rtc_config_debug_enabled):
|
||||
"""Test RTCProcessor initializes with debug tracker."""
|
||||
processor = RTCProcessor(rtc_config_debug_enabled)
|
||||
assert processor.rtc_config == rtc_config_debug_enabled
|
||||
assert processor.tracker is not None
|
||||
assert processor.tracker.enabled is True
|
||||
|
||||
|
||||
def test_rtc_processor_initialization_without_debug(rtc_config_debug_disabled):
|
||||
"""Test RTCProcessor initializes without debug tracker."""
|
||||
processor = RTCProcessor(rtc_config_debug_disabled)
|
||||
assert processor.rtc_config == rtc_config_debug_disabled
|
||||
assert processor.tracker is None
|
||||
|
||||
|
||||
# ====================== Tracker Proxy Methods Tests ======================
|
||||
|
||||
|
||||
def test_track_when_tracker_enabled(rtc_processor_debug_enabled, sample_x_t):
|
||||
"""Test track() forwards to tracker when enabled."""
|
||||
rtc_processor_debug_enabled.track(
|
||||
time=torch.tensor(0.5),
|
||||
x_t=sample_x_t,
|
||||
v_t=sample_x_t,
|
||||
guidance_weight=2.0,
|
||||
)
|
||||
|
||||
# Should have tracked one step
|
||||
steps = rtc_processor_debug_enabled.get_all_debug_steps()
|
||||
assert len(steps) == 1
|
||||
assert steps[0].time == 0.5
|
||||
|
||||
|
||||
def test_track_when_tracker_disabled(rtc_processor_debug_disabled, sample_x_t):
|
||||
"""Test track() does nothing when tracker disabled."""
|
||||
# Should not raise error
|
||||
rtc_processor_debug_disabled.track(
|
||||
time=torch.tensor(0.5),
|
||||
x_t=sample_x_t,
|
||||
v_t=sample_x_t,
|
||||
)
|
||||
|
||||
# Should return empty list
|
||||
steps = rtc_processor_debug_disabled.get_all_debug_steps()
|
||||
assert len(steps) == 0
|
||||
|
||||
|
||||
def test_get_all_debug_steps_when_enabled(rtc_processor_debug_enabled, sample_x_t):
|
||||
"""Test get_all_debug_steps() returns tracked steps."""
|
||||
rtc_processor_debug_enabled.track(time=torch.tensor(0.5), x_t=sample_x_t)
|
||||
rtc_processor_debug_enabled.track(time=torch.tensor(0.4), x_t=sample_x_t)
|
||||
|
||||
steps = rtc_processor_debug_enabled.get_all_debug_steps()
|
||||
assert len(steps) == 2
|
||||
|
||||
|
||||
def test_get_all_debug_steps_when_disabled(rtc_processor_debug_disabled):
|
||||
"""Test get_all_debug_steps() returns empty list when disabled."""
|
||||
steps = rtc_processor_debug_disabled.get_all_debug_steps()
|
||||
assert steps == []
|
||||
assert isinstance(steps, list)
|
||||
|
||||
|
||||
def test_is_debug_enabled_when_tracker_exists(rtc_processor_debug_enabled):
|
||||
"""Test is_debug_enabled() returns True when tracker enabled."""
|
||||
assert rtc_processor_debug_enabled.is_debug_enabled() is True
|
||||
|
||||
|
||||
def test_is_debug_enabled_when_tracker_disabled(rtc_processor_debug_disabled):
|
||||
"""Test is_debug_enabled() returns False when tracker disabled."""
|
||||
assert rtc_processor_debug_disabled.is_debug_enabled() is False
|
||||
|
||||
|
||||
def test_reset_tracker_when_enabled(rtc_processor_debug_enabled, sample_x_t):
|
||||
"""Test reset_tracker() clears tracked steps."""
|
||||
rtc_processor_debug_enabled.track(time=torch.tensor(0.5), x_t=sample_x_t)
|
||||
rtc_processor_debug_enabled.track(time=torch.tensor(0.4), x_t=sample_x_t)
|
||||
assert len(rtc_processor_debug_enabled.get_all_debug_steps()) == 2
|
||||
|
||||
rtc_processor_debug_enabled.reset_tracker()
|
||||
assert len(rtc_processor_debug_enabled.get_all_debug_steps()) == 0
|
||||
|
||||
|
||||
def test_reset_tracker_when_disabled(rtc_processor_debug_disabled):
|
||||
"""Test reset_tracker() doesn't error when tracker disabled."""
|
||||
rtc_processor_debug_disabled.reset_tracker() # Should not raise
|
||||
|
||||
|
||||
# ====================== get_prefix_weights Tests ======================
|
||||
|
||||
|
||||
def test_get_prefix_weights_zeros_schedule():
|
||||
"""Test get_prefix_weights with ZEROS schedule."""
|
||||
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.ZEROS)
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
weights = processor.get_prefix_weights(start=5, end=10, total=20)
|
||||
|
||||
# First 5 should be 1.0, rest should be 0.0
|
||||
assert weights.shape == (20,)
|
||||
assert torch.all(weights[:5] == 1.0)
|
||||
assert torch.all(weights[5:] == 0.0)
|
||||
|
||||
|
||||
def test_get_prefix_weights_ones_schedule():
|
||||
"""Test get_prefix_weights with ONES schedule."""
|
||||
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.ONES)
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
weights = processor.get_prefix_weights(start=5, end=15, total=20)
|
||||
|
||||
# First 15 should be 1.0, rest should be 0.0
|
||||
assert weights.shape == (20,)
|
||||
assert torch.all(weights[:15] == 1.0)
|
||||
assert torch.all(weights[15:] == 0.0)
|
||||
|
||||
|
||||
def test_get_prefix_weights_linear_schedule():
|
||||
"""Test get_prefix_weights with LINEAR schedule."""
|
||||
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.LINEAR)
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
weights = processor.get_prefix_weights(start=5, end=14, total=25)
|
||||
|
||||
# Should have shape (20,)
|
||||
assert weights.shape == (25,)
|
||||
|
||||
# First 5 should be 1.0 (leading ones)
|
||||
assert torch.all(weights[:5] == 1.0)
|
||||
|
||||
# Middle section (5:15) should be linearly decreasing from 1 to 0
|
||||
middle_weights = torch.tensor([0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1])
|
||||
assert torch.allclose(weights[5:14], middle_weights)
|
||||
|
||||
# Last 5 should be 0.0 (trailing zeros)
|
||||
assert torch.all(weights[14:] == 0.0)
|
||||
|
||||
|
||||
def test_get_prefix_weights_exp_schedule():
|
||||
"""Test get_prefix_weights with EXP schedule."""
|
||||
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.EXP)
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
weights = processor.get_prefix_weights(start=5, end=14, total=25)
|
||||
|
||||
# Should have shape (20,)
|
||||
assert weights.shape == (25,)
|
||||
|
||||
# First 5 should be 1.0 (leading ones)
|
||||
assert torch.all(weights[:5] == 1.0)
|
||||
|
||||
# Middle section should be exponentially weighted
|
||||
middle_weights = torch.tensor([0.7645, 0.5706, 0.4130, 0.2871, 0.1888, 0.1145, 0.0611, 0.0258, 0.0061])
|
||||
assert torch.allclose(weights[5:14], middle_weights, atol=1e-4)
|
||||
|
||||
# Last 5 should be 0.0 (trailing zeros)
|
||||
assert torch.all(weights[14:] == 0.0)
|
||||
|
||||
|
||||
def test_get_prefix_weights_with_start_equals_end():
|
||||
"""Test get_prefix_weights when start equals end."""
|
||||
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.LINEAR)
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
weights = processor.get_prefix_weights(start=10, end=10, total=20)
|
||||
|
||||
# Should have ones up to start, then zeros
|
||||
assert torch.all(weights[:10] == 1.0)
|
||||
assert torch.all(weights[10:] == 0.0)
|
||||
|
||||
|
||||
def test_get_prefix_weights_with_start_greater_than_end():
|
||||
"""Test get_prefix_weights when start > end (gets clamped)."""
|
||||
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.LINEAR)
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
# start > end should use min(start, end) = end
|
||||
weights = processor.get_prefix_weights(start=15, end=10, total=20)
|
||||
|
||||
# Should have ones up to end (10), then zeros
|
||||
assert torch.all(weights[:10] == 1.0)
|
||||
assert torch.all(weights[10:] == 0.0)
|
||||
|
||||
|
||||
# ====================== Helper Method Tests ======================
|
||||
|
||||
|
||||
def test_linweights_with_end_equals_start():
|
||||
"""Test _linweights when end equals start."""
|
||||
config = RTCConfig()
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
weights = processor._linweights(start=10, end=10, total=20)
|
||||
|
||||
# Should return empty tensor
|
||||
assert len(weights) == 0
|
||||
|
||||
|
||||
def test_linweights_with_end_less_than_start():
|
||||
"""Test _linweights when end < start."""
|
||||
config = RTCConfig()
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
weights = processor._linweights(start=15, end=10, total=20)
|
||||
|
||||
# Should return empty tensor
|
||||
assert len(weights) == 0
|
||||
|
||||
|
||||
def test_add_trailing_zeros_normal():
|
||||
"""Test _add_trailing_zeros adds zeros correctly."""
|
||||
config = RTCConfig()
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
weights = torch.tensor([1.0, 0.8, 0.6, 0.4, 0.2])
|
||||
result = processor._add_trailing_zeros(weights, total=10, end=5)
|
||||
|
||||
# Should add 5 zeros (total - end = 10 - 5 = 5)
|
||||
assert len(result) == 10
|
||||
assert torch.all(result[:5] == weights)
|
||||
assert torch.all(result[5:] == 0.0)
|
||||
|
||||
|
||||
def test_add_trailing_zeros_no_zeros_needed():
|
||||
"""Test _add_trailing_zeros when no zeros needed."""
|
||||
config = RTCConfig()
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
weights = torch.tensor([1.0, 0.8, 0.6])
|
||||
result = processor._add_trailing_zeros(weights, total=3, end=5)
|
||||
|
||||
# zeros_len = 3 - 5 = -2 <= 0, so no zeros added
|
||||
assert torch.equal(result, weights)
|
||||
|
||||
|
||||
def test_add_leading_ones_normal():
|
||||
"""Test _add_leading_ones adds ones correctly."""
|
||||
config = RTCConfig()
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
weights = torch.tensor([0.8, 0.6, 0.4, 0.2, 0.0])
|
||||
result = processor._add_leading_ones(weights, start=3, total=10)
|
||||
|
||||
# Should add 3 ones at the start
|
||||
assert len(result) == 8
|
||||
assert torch.all(result[:3] == 1.0)
|
||||
assert torch.all(result[3:] == weights)
|
||||
|
||||
|
||||
def test_add_leading_ones_no_ones_needed():
|
||||
"""Test _add_leading_ones when no ones needed."""
|
||||
config = RTCConfig()
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
weights = torch.tensor([0.8, 0.6, 0.4])
|
||||
result = processor._add_leading_ones(weights, start=0, total=10)
|
||||
|
||||
# ones_len = 0, so no ones added
|
||||
assert torch.equal(result, weights)
|
||||
|
||||
|
||||
def test_get_prefix_weights_with_start_equals_total():
|
||||
"""Test get_prefix_weights when start equals total."""
|
||||
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.LINEAR)
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
weights = processor.get_prefix_weights(start=10, end=10, total=20)
|
||||
|
||||
# Should have ones up to start, then zeros
|
||||
assert len(weights) == 20
|
||||
assert torch.all(weights[:10] == 1.0)
|
||||
assert torch.all(weights[10:] == 0.0)
|
||||
|
||||
|
||||
def test_get_prefix_weights_with_total_less_than_start():
|
||||
"""Test get_prefix_weights when total less than start."""
|
||||
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.LINEAR)
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
weights = processor.get_prefix_weights(start=10, end=10, total=5)
|
||||
|
||||
# Should have ones up to start, then zeros
|
||||
assert len(weights) == 5
|
||||
assert torch.all(weights == 1.0)
|
||||
|
||||
|
||||
# ====================== denoise_step Tests ======================
|
||||
|
||||
|
||||
def test_denoise_step_without_prev_chunk(rtc_processor_debug_disabled):
|
||||
"""Test denoise_step without previous chunk (no guidance)."""
|
||||
x_t = torch.randn(1, 50, 6)
|
||||
|
||||
# Mock denoiser that returns fixed velocity
|
||||
def mock_denoiser(x):
|
||||
return torch.ones_like(x) * 0.5
|
||||
|
||||
result = rtc_processor_debug_disabled.denoise_step(
|
||||
x_t=x_t,
|
||||
prev_chunk_left_over=None,
|
||||
inference_delay=5,
|
||||
time=torch.tensor(0.5),
|
||||
original_denoise_step_partial=mock_denoiser,
|
||||
)
|
||||
|
||||
# Should return v_t unchanged (no guidance)
|
||||
expected = mock_denoiser(x_t)
|
||||
assert torch.allclose(result, expected)
|
||||
|
||||
|
||||
def test_denoise_step_with_prev_chunk(rtc_processor_debug_disabled):
|
||||
"""Test denoise_step with previous chunk applies guidance."""
|
||||
x_t = torch.ones(1, 20, 1)
|
||||
prev_chunk = torch.full((1, 20, 1), 0.1)
|
||||
|
||||
def mock_denoiser(x):
|
||||
return x * 0.5
|
||||
|
||||
result = rtc_processor_debug_disabled.denoise_step(
|
||||
x_t=x_t,
|
||||
prev_chunk_left_over=prev_chunk,
|
||||
inference_delay=5,
|
||||
time=torch.tensor(0.5),
|
||||
original_denoise_step_partial=mock_denoiser,
|
||||
)
|
||||
|
||||
expected_result = torch.tensor(
|
||||
[
|
||||
[
|
||||
[1.8000],
|
||||
[1.8000],
|
||||
[1.8000],
|
||||
[1.8000],
|
||||
[1.8000],
|
||||
[1.5833],
|
||||
[1.3667],
|
||||
[1.1500],
|
||||
[0.9333],
|
||||
[0.7167],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
]
|
||||
]
|
||||
)
|
||||
|
||||
assert torch.allclose(result, expected_result, atol=1e-4)
|
||||
|
||||
|
||||
def test_denoise_step_adds_batch_dimension():
|
||||
"""Test denoise_step handles 2D input by adding batch dimension."""
|
||||
config = RTCConfig(execution_horizon=10, max_guidance_weight=5.0)
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
# 2D input (no batch dimension)
|
||||
x_t = torch.randn(10, 6)
|
||||
prev_chunk = torch.randn(5, 6)
|
||||
|
||||
def mock_denoiser(x):
|
||||
return x * 0.5
|
||||
|
||||
result = processor.denoise_step(
|
||||
x_t=x_t,
|
||||
prev_chunk_left_over=prev_chunk,
|
||||
inference_delay=5,
|
||||
time=torch.tensor(0.5),
|
||||
original_denoise_step_partial=mock_denoiser,
|
||||
)
|
||||
|
||||
# Output should be 2D (batch dimension removed)
|
||||
assert result.ndim == 2
|
||||
assert result.shape == (10, 6)
|
||||
|
||||
|
||||
def test_denoise_step_uses_custom_execution_horizon():
|
||||
"""Test denoise_step uses custom execution_horizon parameter."""
|
||||
config = RTCConfig(execution_horizon=10)
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
x_t = torch.ones(1, 20, 1)
|
||||
prev_chunk = torch.full((1, 15, 1), 0.1)
|
||||
|
||||
def mock_denoiser(x):
|
||||
return x * 0.5
|
||||
|
||||
result = processor.denoise_step(
|
||||
x_t=x_t,
|
||||
prev_chunk_left_over=prev_chunk,
|
||||
inference_delay=5,
|
||||
time=torch.tensor(0.5),
|
||||
original_denoise_step_partial=mock_denoiser,
|
||||
execution_horizon=15,
|
||||
)
|
||||
|
||||
expected_result = torch.tensor(
|
||||
[
|
||||
[
|
||||
[1.8000],
|
||||
[1.8000],
|
||||
[1.8000],
|
||||
[1.8000],
|
||||
[1.8000],
|
||||
[1.6818],
|
||||
[1.5636],
|
||||
[1.4455],
|
||||
[1.3273],
|
||||
[1.2091],
|
||||
[1.0909],
|
||||
[0.9727],
|
||||
[0.8545],
|
||||
[0.7364],
|
||||
[0.6182],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
]
|
||||
]
|
||||
)
|
||||
|
||||
assert torch.allclose(result, expected_result, atol=1e-4)
|
||||
|
||||
|
||||
def test_denoise_step_guidance_weight_at_time_zero():
|
||||
"""Test denoise_step handles time=0 (tau=1) without NaN/Inf."""
|
||||
config = RTCConfig(max_guidance_weight=10.0)
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
x_t = torch.ones(1, 20, 1)
|
||||
prev_chunk = torch.full((1, 20, 1), 0.1)
|
||||
|
||||
def mock_denoiser(x):
|
||||
return x * 0.5
|
||||
|
||||
result = processor.denoise_step(
|
||||
x_t=x_t,
|
||||
prev_chunk_left_over=prev_chunk,
|
||||
inference_delay=5,
|
||||
time=torch.tensor(0.0),
|
||||
original_denoise_step_partial=mock_denoiser,
|
||||
)
|
||||
|
||||
expected_result = torch.tensor(
|
||||
[
|
||||
[
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
]
|
||||
]
|
||||
)
|
||||
|
||||
assert torch.allclose(result, expected_result, atol=1e-4)
|
||||
|
||||
|
||||
def test_denoise_step_with_real_denoise_step_partial():
|
||||
"""Test denoise_step with a real denoiser."""
|
||||
config = RTCConfig(max_guidance_weight=10.0)
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
batch_size = 10
|
||||
action_dim = 6
|
||||
chunk_size = 20
|
||||
|
||||
x_t = torch.ones(batch_size, chunk_size, action_dim)
|
||||
prev_chunk = torch.full((batch_size, chunk_size, action_dim), 0.1)
|
||||
|
||||
velocity_function = torch.nn.Sequential(
|
||||
torch.nn.Linear(action_dim, 1000),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Linear(1000, 256),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Linear(256, action_dim),
|
||||
)
|
||||
|
||||
def mock_denoiser(x):
|
||||
return velocity_function(x)
|
||||
|
||||
result = processor.denoise_step(
|
||||
x_t=x_t,
|
||||
prev_chunk_left_over=prev_chunk,
|
||||
inference_delay=5,
|
||||
time=torch.tensor(0.5),
|
||||
original_denoise_step_partial=mock_denoiser,
|
||||
)
|
||||
|
||||
assert result.shape == (batch_size, chunk_size, action_dim)
|
||||
|
||||
|
||||
def test_denoise_step_guidance_weight_at_time_one():
|
||||
"""Test denoise_step handles time=1 (tau=0) with max_guidance_weight clamping."""
|
||||
config = RTCConfig(max_guidance_weight=10.0)
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
x_t = torch.randn(1, 50, 6)
|
||||
prev_chunk = torch.randn(1, 50, 6)
|
||||
|
||||
def mock_denoiser(x):
|
||||
return torch.ones_like(x) * 0.5
|
||||
|
||||
# Time = 1 => tau = 0, c = (1-tau)/tau = 1/0 = inf (clamped to max_guidance_weight)
|
||||
result = processor.denoise_step(
|
||||
x_t=x_t,
|
||||
prev_chunk_left_over=prev_chunk,
|
||||
inference_delay=5,
|
||||
time=torch.tensor(1.0),
|
||||
original_denoise_step_partial=mock_denoiser,
|
||||
)
|
||||
|
||||
# Should clamp to max_guidance_weight (no Inf)
|
||||
assert not torch.any(torch.isinf(result))
|
||||
|
||||
|
||||
def test_denoise_step_tracks_debug_info(rtc_processor_debug_enabled):
|
||||
"""Test denoise_step tracks debug information when enabled."""
|
||||
x_t = torch.randn(1, 50, 6)
|
||||
prev_chunk = torch.randn(1, 50, 6)
|
||||
|
||||
def mock_denoiser(x):
|
||||
return torch.ones_like(x) * 0.5
|
||||
|
||||
rtc_processor_debug_enabled.denoise_step(
|
||||
x_t=x_t,
|
||||
prev_chunk_left_over=prev_chunk,
|
||||
inference_delay=5,
|
||||
time=torch.tensor(0.5),
|
||||
original_denoise_step_partial=mock_denoiser,
|
||||
)
|
||||
|
||||
# Should have tracked one step
|
||||
steps = rtc_processor_debug_enabled.get_all_debug_steps()
|
||||
assert len(steps) == 1
|
||||
|
||||
# Check tracked values
|
||||
step = steps[0]
|
||||
assert step.time == 0.5
|
||||
assert step.x1_t is not None
|
||||
assert step.correction is not None
|
||||
assert step.err is not None
|
||||
assert step.weights is not None
|
||||
assert step.guidance_weight is not None
|
||||
assert step.inference_delay == 5
|
||||
|
||||
|
||||
def test_denoise_step_doesnt_track_without_debug(rtc_processor_debug_disabled):
|
||||
"""Test denoise_step doesn't track when debug disabled."""
|
||||
x_t = torch.randn(1, 50, 6)
|
||||
prev_chunk = torch.randn(1, 50, 6)
|
||||
|
||||
def mock_denoiser(x):
|
||||
return torch.ones_like(x) * 0.5
|
||||
|
||||
rtc_processor_debug_disabled.denoise_step(
|
||||
x_t=x_t,
|
||||
prev_chunk_left_over=prev_chunk,
|
||||
inference_delay=5,
|
||||
time=torch.tensor(0.5),
|
||||
original_denoise_step_partial=mock_denoiser,
|
||||
)
|
||||
|
||||
# Should not track
|
||||
steps = rtc_processor_debug_disabled.get_all_debug_steps()
|
||||
assert len(steps) == 0
|
||||
|
||||
|
||||
# ====================== Integration Tests ======================
|
||||
|
||||
|
||||
def test_denoise_step_full_workflow():
|
||||
"""Test complete denoise_step workflow."""
|
||||
config = RTCConfig(
|
||||
enabled=True,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.LINEAR,
|
||||
max_guidance_weight=5.0,
|
||||
execution_horizon=10,
|
||||
debug=True,
|
||||
)
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
# Simulate two denoising steps
|
||||
x_t1 = torch.randn(1, 50, 6)
|
||||
x_t2 = torch.randn(1, 50, 6)
|
||||
|
||||
def mock_denoiser(x):
|
||||
return torch.randn_like(x) * 0.1
|
||||
|
||||
# First step - no guidance
|
||||
result1 = processor.denoise_step(
|
||||
x_t=x_t1,
|
||||
prev_chunk_left_over=None,
|
||||
inference_delay=5,
|
||||
time=torch.tensor(0.8),
|
||||
original_denoise_step_partial=mock_denoiser,
|
||||
)
|
||||
|
||||
# Second step - with guidance
|
||||
result2 = processor.denoise_step(
|
||||
x_t=x_t2,
|
||||
prev_chunk_left_over=result1,
|
||||
inference_delay=5,
|
||||
time=torch.tensor(0.6),
|
||||
original_denoise_step_partial=mock_denoiser,
|
||||
)
|
||||
|
||||
# Both should complete successfully
|
||||
assert result1.shape == (1, 50, 6)
|
||||
assert result2.shape == (1, 50, 6)
|
||||
|
||||
# Should have tracked one step (second one, first had no prev_chunk)
|
||||
steps = processor.get_all_debug_steps()
|
||||
assert len(steps) == 1
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_denoise_step_with_cuda_tensors():
|
||||
"""Test denoise_step works with CUDA tensors."""
|
||||
config = RTCConfig(execution_horizon=10, max_guidance_weight=5.0)
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
x_t = torch.randn(1, 50, 6, device="cuda")
|
||||
prev_chunk = torch.randn(1, 50, 6, device="cuda")
|
||||
|
||||
def mock_denoiser(x):
|
||||
return torch.ones_like(x) * 0.5
|
||||
|
||||
result = processor.denoise_step(
|
||||
x_t=x_t,
|
||||
prev_chunk_left_over=prev_chunk,
|
||||
inference_delay=5,
|
||||
time=torch.tensor(0.5),
|
||||
original_denoise_step_partial=mock_denoiser,
|
||||
)
|
||||
|
||||
# Result should be on CUDA
|
||||
assert result.device.type == "cuda"
|
||||
assert result.shape == x_t.shape
|
||||
|
||||
|
||||
def test_denoise_step_deterministic_with_same_inputs():
|
||||
"""Test denoise_step produces same output with same inputs."""
|
||||
config = RTCConfig(execution_horizon=10, max_guidance_weight=5.0)
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
torch.manual_seed(42)
|
||||
x_t = torch.randn(1, 50, 6)
|
||||
prev_chunk = torch.randn(1, 50, 6)
|
||||
|
||||
def deterministic_denoiser(x):
|
||||
return torch.ones_like(x) * 0.5
|
||||
|
||||
result1 = processor.denoise_step(
|
||||
x_t=x_t.clone(),
|
||||
prev_chunk_left_over=prev_chunk.clone(),
|
||||
inference_delay=5,
|
||||
time=torch.tensor(0.5),
|
||||
original_denoise_step_partial=deterministic_denoiser,
|
||||
)
|
||||
|
||||
result2 = processor.denoise_step(
|
||||
x_t=x_t.clone(),
|
||||
prev_chunk_left_over=prev_chunk.clone(),
|
||||
inference_delay=5,
|
||||
time=torch.tensor(0.5),
|
||||
original_denoise_step_partial=deterministic_denoiser,
|
||||
)
|
||||
|
||||
# Should produce identical results
|
||||
assert torch.allclose(result1, result2)
|
||||
@@ -0,0 +1,309 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test SmolVLA policy with Real-Time Chunking (RTC) enabled during inference."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature, RTCAttentionSchedule # noqa: E402
|
||||
from lerobot.policies.factory import make_pre_post_processors # noqa: E402
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig # noqa: E402
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig # noqa: F401
|
||||
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy # noqa: F401
|
||||
from lerobot.utils.random_utils import set_seed # noqa: E402
|
||||
from tests.utils import require_cuda # noqa: E402
|
||||
|
||||
|
||||
@require_cuda
|
||||
def test_smolvla_rtc_initialization():
|
||||
"""Test SmolVLA policy can initialize RTC processor."""
|
||||
set_seed(42)
|
||||
|
||||
config = SmolVLAConfig(max_action_dim=7, chunk_size=50)
|
||||
|
||||
# Add RTC config
|
||||
config.rtc_config = RTCConfig(
|
||||
enabled=True,
|
||||
execution_horizon=10,
|
||||
max_guidance_weight=5.0,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||
debug=False,
|
||||
)
|
||||
|
||||
config.input_features = {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
|
||||
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||
}
|
||||
|
||||
# Instantiate policy
|
||||
policy = SmolVLAPolicy(config)
|
||||
|
||||
# Verify RTC processor is initialized
|
||||
assert hasattr(policy, "rtc_processor")
|
||||
assert policy.rtc_processor is not None
|
||||
assert policy.rtc_processor.rtc_config.enabled is True
|
||||
|
||||
print("✓ SmolVLA RTC initialization: Test passed")
|
||||
|
||||
|
||||
@require_cuda
|
||||
def test_smolvla_rtc_initialization_without_rtc_config():
|
||||
"""Test SmolVLA policy can initialize without RTC config."""
|
||||
set_seed(42)
|
||||
|
||||
config = SmolVLAConfig(max_action_dim=7, chunk_size=50)
|
||||
|
||||
# Instantiate policy
|
||||
policy = SmolVLAPolicy(config)
|
||||
|
||||
# Verify RTC processor is not initialized
|
||||
assert hasattr(policy, "rtc_processor")
|
||||
assert policy.rtc_processor is None
|
||||
assert policy.model.rtc_processor is None
|
||||
assert policy._rtc_enabled() is False
|
||||
|
||||
print("✓ SmolVLA RTC initialization without RTC config: Test passed")
|
||||
|
||||
|
||||
@require_cuda
|
||||
@pytest.mark.skipif(True, reason="Requires pretrained SmolVLA model weights")
|
||||
def test_smolvla_rtc_inference_with_prev_chunk():
|
||||
"""Test SmolVLA policy inference with RTC and previous chunk."""
|
||||
set_seed(42)
|
||||
|
||||
config = SmolVLAConfig(max_action_dim=7, chunk_size=50)
|
||||
|
||||
# Add RTC config
|
||||
config.rtc_config = RTCConfig(
|
||||
enabled=True,
|
||||
execution_horizon=10,
|
||||
max_guidance_weight=5.0,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||
debug=False,
|
||||
)
|
||||
|
||||
config.input_features = {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
|
||||
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||
}
|
||||
|
||||
# Create dataset stats
|
||||
dataset_stats = {
|
||||
"observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
|
||||
"action": {"mean": torch.zeros(7), "std": torch.ones(7)},
|
||||
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
|
||||
}
|
||||
|
||||
# Instantiate policy and create preprocessor
|
||||
policy = SmolVLAPolicy(config)
|
||||
policy.eval()
|
||||
preprocessor, _ = make_pre_post_processors(
|
||||
policy_cfg=config, pretrained_path=None, dataset_stats=dataset_stats
|
||||
)
|
||||
|
||||
device = config.device
|
||||
|
||||
# Create dummy batch
|
||||
batch = {
|
||||
"observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
|
||||
"observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
|
||||
"task": ["Pick up the object"],
|
||||
}
|
||||
batch = preprocessor(batch)
|
||||
|
||||
# Create previous chunk
|
||||
prev_chunk = torch.randn(1, 25, 7, dtype=torch.float32, device=device)
|
||||
|
||||
with torch.no_grad():
|
||||
# Use same noise for fair comparison
|
||||
noise = policy.model.sample_noise((1, config.chunk_size, 7), device)
|
||||
|
||||
# Test with RTC and previous chunk
|
||||
actions_with_rtc = policy.predict_action_chunk(
|
||||
batch,
|
||||
noise=noise.clone(),
|
||||
prev_chunk_left_over=prev_chunk,
|
||||
inference_delay=4,
|
||||
execution_horizon=10,
|
||||
)
|
||||
|
||||
# Test without RTC for comparison
|
||||
policy.config.rtc_config.enabled = False
|
||||
actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone())
|
||||
policy.config.rtc_config.enabled = True
|
||||
|
||||
# Verify shapes
|
||||
assert actions_with_rtc.shape == (1, config.chunk_size, 7)
|
||||
assert actions_without_rtc.shape == (1, config.chunk_size, 7)
|
||||
|
||||
# With previous chunk, actions should be different (RTC guidance applied)
|
||||
assert not torch.allclose(actions_with_rtc, actions_without_rtc, rtol=1e-3)
|
||||
|
||||
print("✓ SmolVLA RTC inference with prev_chunk: Test passed")
|
||||
|
||||
|
||||
@require_cuda
|
||||
@pytest.mark.skipif(True, reason="Requires pretrained SmolVLA model weights")
|
||||
def test_smolvla_rtc_inference_without_prev_chunk():
|
||||
"""Test SmolVLA policy inference with RTC but no previous chunk (RTC should have no effect)."""
|
||||
set_seed(42)
|
||||
|
||||
config = SmolVLAConfig(max_action_dim=7, chunk_size=50)
|
||||
|
||||
# Add RTC config
|
||||
config.rtc_config = RTCConfig(
|
||||
enabled=True,
|
||||
execution_horizon=10,
|
||||
max_guidance_weight=5.0,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||
debug=False,
|
||||
)
|
||||
|
||||
config.input_features = {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
|
||||
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||
}
|
||||
|
||||
# Create dataset stats
|
||||
dataset_stats = {
|
||||
"observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
|
||||
"action": {"mean": torch.zeros(7), "std": torch.ones(7)},
|
||||
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
|
||||
}
|
||||
|
||||
# Instantiate policy and create preprocessor
|
||||
policy = SmolVLAPolicy(config)
|
||||
policy.eval()
|
||||
preprocessor, _ = make_pre_post_processors(
|
||||
policy_cfg=config, pretrained_path=None, dataset_stats=dataset_stats
|
||||
)
|
||||
|
||||
device = config.device
|
||||
|
||||
# Create dummy batch
|
||||
batch = {
|
||||
"observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
|
||||
"observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
|
||||
"task": ["Pick up the object"],
|
||||
}
|
||||
batch = preprocessor(batch)
|
||||
|
||||
with torch.no_grad():
|
||||
# Use same noise for fair comparison
|
||||
noise = policy.model.sample_noise((1, config.chunk_size, 7), device)
|
||||
|
||||
# Test with RTC enabled but no previous chunk
|
||||
actions_with_rtc_no_prev = policy.predict_action_chunk(
|
||||
batch,
|
||||
noise=noise.clone(),
|
||||
prev_chunk_left_over=None,
|
||||
)
|
||||
|
||||
# Test without RTC
|
||||
policy.config.rtc_config.enabled = False
|
||||
actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone())
|
||||
policy.config.rtc_config.enabled = True
|
||||
|
||||
# Without previous chunk, RTC should have no effect
|
||||
assert torch.allclose(actions_with_rtc_no_prev, actions_without_rtc, rtol=1e-5)
|
||||
|
||||
print("✓ SmolVLA RTC inference without prev_chunk: Test passed")
|
||||
|
||||
|
||||
@require_cuda
|
||||
@pytest.mark.skipif(True, reason="Requires pretrained SmolVLA model weights")
|
||||
def test_smolvla_rtc_validation_rules():
|
||||
"""Test SmolVLA policy with RTC follows all three validation rules."""
|
||||
set_seed(42)
|
||||
|
||||
config = SmolVLAConfig(max_action_dim=7, chunk_size=50)
|
||||
|
||||
# Add RTC config
|
||||
config.rtc_config = RTCConfig(
|
||||
enabled=True,
|
||||
execution_horizon=10,
|
||||
max_guidance_weight=5.0,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||
debug=False,
|
||||
)
|
||||
|
||||
config.input_features = {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
|
||||
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||
}
|
||||
|
||||
# Create dataset stats
|
||||
dataset_stats = {
|
||||
"observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
|
||||
"action": {"mean": torch.zeros(7), "std": torch.ones(7)},
|
||||
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
|
||||
}
|
||||
|
||||
# Instantiate policy and create preprocessor
|
||||
policy = SmolVLAPolicy(config)
|
||||
policy.eval()
|
||||
preprocessor, _ = make_pre_post_processors(
|
||||
policy_cfg=config, pretrained_path=None, dataset_stats=dataset_stats
|
||||
)
|
||||
|
||||
device = config.device
|
||||
|
||||
# Create dummy batch
|
||||
batch = {
|
||||
"observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
|
||||
"observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
|
||||
"task": ["Pick up the object"],
|
||||
}
|
||||
batch = preprocessor(batch)
|
||||
|
||||
# Create previous chunk
|
||||
prev_chunk = torch.randn(1, 25, 7, dtype=torch.float32, device=device)
|
||||
|
||||
inference_delay = 4
|
||||
execution_horizon = 10
|
||||
|
||||
with torch.no_grad():
|
||||
# Use same noise for fair comparison
|
||||
noise = policy.model.sample_noise((1, config.chunk_size, 7), device)
|
||||
|
||||
# Test with RTC
|
||||
actions_with_rtc = policy.predict_action_chunk(
|
||||
batch,
|
||||
noise=noise.clone(),
|
||||
prev_chunk_left_over=prev_chunk,
|
||||
inference_delay=inference_delay,
|
||||
execution_horizon=execution_horizon,
|
||||
)
|
||||
|
||||
# Test without RTC
|
||||
policy.config.rtc_config.enabled = False
|
||||
actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone())
|
||||
policy.config.rtc_config.enabled = True
|
||||
|
||||
assert not torch.allclose(actions_with_rtc, actions_without_rtc, rtol=1e-3)
|
||||
Reference in New Issue
Block a user