mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 16:49:55 +00:00
@@ -0,0 +1,263 @@
# RTC Profiling Guide

This guide explains how to profile RTC (Real-Time Chunking) performance to identify bottlenecks and understand why RTC might be slower than expected.

## Quick Start

### 1. Profile with Real Robot (Profiled Version)

Use `eval_with_real_robot_profiled.py` to profile actual robot execution:

```bash
# With RTC enabled
uv run examples/rtc/eval_with_real_robot_profiled.py \
    --policy.path=helper2424/pi05_check_rtc \
    --policy.device=mps \
    --rtc.enabled=true \
    --rtc.execution_horizon=20 \
    --robot.type=so100_follower \
    --robot.port=/dev/tty.usbmodem58FA0834591 \
    --robot.id=so100_follower \
    --robot.cameras="{ gripper: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}}" \
    --task="Move green small object into the purple platform" \
    --duration=30

# Without RTC for comparison
uv run examples/rtc/eval_with_real_robot_profiled.py \
    --policy.path=helper2424/pi05_check_rtc \
    --policy.device=mps \
    --rtc.enabled=false \
    --robot.type=so100_follower \
    --robot.port=/dev/tty.usbmodem58FA0834591 \
    --robot.id=so100_follower \
    --robot.cameras="{ gripper: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}}" \
    --task="Move green small object into the purple platform" \
    --duration=30
```

**Output**: At the end of execution, you'll see a detailed breakdown of timing for each component:

- `get_actions.policy_inference` - Time spent in policy inference
- `get_actions.preprocessing` - Time spent preprocessing observations
- `get_actions.postprocessing` - Time spent postprocessing actions
- `get_actions.action_queue_merge` - Time spent merging actions with RTC
- `robot.get_observation` - Time to get observations from the robot
- `robot.send_action` - Time to send actions to the robot
- And more...
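Under the hood, summaries like this are typically produced by a small registry that maps section names to recorded durations. A minimal sketch of how such a registry might work (hypothetical; the real `lerobot.utils.profiling` module may differ):

```python
import time
from collections import defaultdict
from contextlib import contextmanager

# name -> list of recorded durations in seconds
_timings = defaultdict(list)

@contextmanager
def profile_section(name):
    """Record the wall-clock duration of the enclosed block under `name`."""
    start = time.perf_counter()
    try:
        yield
    finally:
        _timings[name].append(time.perf_counter() - start)

def summary():
    """Return {name: (call_count, mean_ms)} for every recorded section."""
    return {
        name: (len(durations), 1000.0 * sum(durations) / len(durations))
        for name, durations in _timings.items()
    }

with profile_section("get_actions.policy_inference"):
    time.sleep(0.01)  # stand-in for real work
```

At the end of a run, `summary()` yields exactly the kind of per-section count/mean table shown above.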

### 2. Profile Without Robot (Comparison Script)

Use `profile_rtc_comparison.py` to profile just the policy inference without needing a robot:

```bash
uv run examples/rtc/profile_rtc_comparison.py \
    --policy_path=helper2424/pi05_check_rtc \
    --device=mps \
    --num_iterations=50 \
    --execution_horizon=20
```

**Output**: Side-by-side comparison of performance with and without RTC, including:

- Mean/min/max inference times
- Throughput (iterations per second)
- A verdict on whether RTC is faster or slower

### 3. Enable Detailed Method-Level Profiling

For even more granular profiling, add the `--enable_detailed_profiling` flag:

```bash
uv run examples/rtc/profile_rtc_comparison.py \
    --policy_path=helper2424/pi05_check_rtc \
    --device=mps \
    --num_iterations=50 \
    --execution_horizon=20 \
    --enable_detailed_profiling
```

This shows timing for individual methods within the policy.

## Understanding the Output

### Key Metrics to Look At

1. **get_actions.policy_inference** - This should be the largest component
   - If RTC is enabled, this includes the RTC guidance overhead
   - Compare it with and without RTC to measure the overhead

2. **get_actions.preprocessing** - Image preprocessing and normalization
   - Should be relatively fast
   - If slow, consider optimizing image processing

3. **get_actions.postprocessing** - Action denormalization
   - Should be minimal
   - If slow, check the postprocessor implementation

4. **get_actions.action_queue_merge** - RTC-specific merging logic
   - Only present when RTC is enabled
   - If it takes significant time, the RTC algorithm may need optimization

5. **robot.get_observation** - Robot communication overhead
   - If slow, check camera/sensor latency
   - Consider reducing image resolution

6. **robot.send_action** - Action execution overhead
   - Should be very fast
   - If slow, check robot communication

### Expected Performance

For a typical Pi0 policy on Apple Silicon (MPS):

- **Without RTC**: ~100-200 ms per inference
- **With RTC**: Should be similar, or slightly faster due to action reuse
- **Preprocessing**: ~5-20 ms depending on the number of cameras
- **Postprocessing**: ~1-5 ms

If RTC is significantly slower, likely causes are:

1. **RTC overhead exceeds benefits** - The guidance computation is expensive
2. **Execution horizon too small** - Not reusing enough actions to amortize the overhead
3. **No compilation** - Try with `--use_torch_compile`
4. **Large prev_actions buffer** - Copying/processing previous actions is slow
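A quick back-of-the-envelope calculation shows how the execution horizon amortizes a fixed per-chunk RTC cost (the numbers below are illustrative, not measurements):

```python
def per_action_cost_ms(base_inference_ms, rtc_overhead_ms, execution_horizon):
    """Total chunk cost spread over the actions actually executed from that chunk."""
    return (base_inference_ms + rtc_overhead_ms) / execution_horizon

# Illustrative numbers: 150 ms base inference, 30 ms RTC overhead per chunk.
for horizon in (5, 10, 20, 30):
    cost = per_action_cost_ms(150.0, 30.0, horizon)
    print(f"horizon={horizon:2d}: {cost:.1f} ms per executed action")
```

Doubling the horizon halves the per-action cost, which is why a small horizon can make RTC look slower than it effectively is.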

## Profiling Your Own Code

### Using the Profiling Decorator

Add profiling to your own methods:

```python
from lerobot.utils.profiling import profile_method, enable_profiling, print_profiling_summary

# Enable profiling
enable_profiling()

# Decorate methods you want to profile
@profile_method
def my_slow_function(x):
    result = ...  # your code
    return result

# At the end of execution
print_profiling_summary()
```

### Using the Profile Context Manager

For profiling specific code blocks:

```python
from lerobot.utils.profiling import profile_section, enable_profiling

enable_profiling()

with profile_section("data_loading"):
    data = load_data()

with profile_section("model_inference"):
    output = model(data)
```

### Adding Profiling to Policy Methods

To profile specific parts of the Pi0 policy, you can add decorators:

```python
# In src/lerobot/policies/pi0/modeling_pi0.py
from lerobot.utils.profiling import profile_method, profile_section

class Pi0Policy:
    @profile_method
    def predict_action_chunk(self, obs, inference_delay=0, prev_chunk_left_over=None):
        # ... existing code ...
        pass

    def _generate_actions_with_rtc(self, ...):
        with profile_section("rtc.guidance_computation"):
            # ... guidance code ...
            pass

        with profile_section("rtc.action_merging"):
            # ... merging code ...
            pass
```

## Analyzing Results

### Comparison Checklist

When comparing RTC vs non-RTC performance, check:

- [ ] Is `policy_inference` time higher with RTC?
- [ ] Is `action_queue_merge` taking significant time?
- [ ] Are you running enough iterations to amortize warmup?
- [ ] Is torch.compile enabled for a fair comparison?
- [ ] Is the execution horizon large enough (>= 10-20)?
- [ ] Are you testing on the same hardware/device?

### Common Bottlenecks

1. **Image preprocessing dominates**
   - Solution: Reduce image resolution, use fewer cameras, or optimize preprocessing

2. **Action queue operations are slow**
   - Solution: Review the queue implementation; consider using a ring buffer

3. **RTC guidance is expensive**
   - Solution: Reduce the guidance weight, simplify the guidance computation, use torch.compile

4. **Robot communication is slow**
   - Solution: Increase the baud rate, reduce the action frequency, optimize the protocol

5. **Memory allocation overhead**
   - Solution: Pre-allocate buffers, reuse tensors, avoid unnecessary copies
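For the action-queue case, a bounded `collections.deque` already behaves like a ring buffer: appends and pops are O(1), and the oldest entries are evicted automatically without reallocation. A sketch, independent of LeRobot's actual queue implementation:

```python
from collections import deque

class ActionQueue:
    """Fixed-capacity FIFO of pending actions backed by a ring buffer."""

    def __init__(self, capacity):
        # maxlen makes the deque drop the oldest entry when full - no reallocation
        self._buf = deque(maxlen=capacity)

    def extend(self, actions):
        """Append a new chunk; the oldest unexecuted actions are evicted if full."""
        self._buf.extend(actions)

    def pop_next(self):
        """Return the next action to execute, or None if the queue is empty."""
        return self._buf.popleft() if self._buf else None

    def __len__(self):
        return len(self._buf)

queue = ActionQueue(capacity=3)
queue.extend([1, 2, 3, 4])  # capacity 3 -> the oldest action (1) is evicted
first = queue.pop_next()
```

Whether evicting old actions is the right policy depends on how chunks are merged; the point here is only the constant-time buffer mechanics.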

## Advanced: Adding Custom Metrics

You can add custom timing metrics to the profiled script:

```python
import time

from lerobot.utils.profiling import record_timing

start = time.perf_counter()
# ... your code ...
duration = time.perf_counter() - start
record_timing("my_custom_metric", duration)
```

## Troubleshooting

### Profiling shows RTC is slower by >50%

1. Check that torch.compile is enabled: `--use_torch_compile`
2. Increase the execution horizon: `--rtc.execution_horizon=30`
3. Verify that inference_delay is calculated correctly
4. Profile with `--enable_detailed_profiling` to find the exact bottleneck

### Profiling output is empty

1. Make sure profiling is enabled with `enable_profiling()`
2. Verify you're running enough iterations (at least 10)
3. Check that the code is actually executing (not short-circuited)

### Inconsistent results between runs

1. Run more iterations: `--num_iterations=100`
2. Increase the number of warmup iterations
3. Check for thermal throttling on the device
4. Ensure no other processes are competing for resources

## Next Steps

1. Run both profiling scripts (with and without the robot)
2. Compare the timing breakdowns
3. Identify the largest bottleneck
4. Focus optimization efforts on that component
5. Re-run profiling to verify improvements

## Questions?

If profiling reveals unexpected bottlenecks or you need help interpreting results, please share:

- The full profiling output
- Your configuration (RTC enabled/disabled, execution horizon, etc.)
- Hardware specs (device type, memory, etc.)
- Policy type and size
@@ -0,0 +1,208 @@

# RTC Profiling - Quick Start

Quick reference for profiling Pi0 with RTC to identify performance bottlenecks.

## 🚀 Quick Commands

### 1. Profile with Real Robot

```bash
# With RTC enabled (profiled version)
uv run examples/rtc/eval_with_real_robot_profiled.py \
    --policy.path=helper2424/pi05_check_rtc \
    --policy.device=mps \
    --rtc.enabled=true \
    --rtc.execution_horizon=20 \
    --robot.type=so100_follower \
    --robot.port=/dev/tty.usbmodem58FA0834591 \
    --robot.cameras="{ gripper: {type: opencv, index_or_path: 0}, front: {type: opencv, index_or_path: 1}}" \
    --task="Pick up object" \
    --duration=30
```

### 2. Compare RTC vs No-RTC (No Robot Needed)

```bash
uv run examples/rtc/profile_rtc_comparison.py \
    --policy_path=helper2424/pi05_check_rtc \
    --device=mps \
    --num_iterations=50 \
    --execution_horizon=20
```

### 3. Detailed RTC Method Profiling

```bash
uv run examples/rtc/profile_pi0_rtc_detailed.py \
    --policy_path=helper2424/pi05_check_rtc \
    --device=mps \
    --num_iterations=20 \
    --execution_horizon=20 \
    --enable_rtc_profiling
```

## 📊 What Each Tool Does

| Tool | Purpose | Needs Robot? |
|------|---------|--------------|
| `eval_with_real_robot_profiled.py` | Profile actual robot execution with RTC | ✅ Yes |
| `profile_rtc_comparison.py` | Compare RTC vs no-RTC side-by-side | ❌ No |
| `profile_pi0_rtc_detailed.py` | Deep dive into RTC internals | ❌ No |

## 🔍 Key Metrics to Watch

### Overall Performance

- **iteration.policy_inference** - Total policy inference time
- **iteration.preprocessing** - Image preprocessing time
- **iteration.postprocessing** - Action denormalization time

### RTC-Specific (with `--enable_rtc_profiling`)

- **rtc.denoise_step.base_denoising** - Time without RTC overhead
- **rtc.denoise_step.autograd_correction** - Gradient computation time
- **rtc.denoise_step.guidance_computation** - Total RTC guidance overhead

### Robot Communication

- **robot.get_observation** - Time to get the robot state
- **robot.send_action** - Time to send an action command

## 🎯 Quick Diagnosis

### RTC is slower than expected?

1. **Check that torch.compile is enabled**
   ```bash
   # Add this flag
   --use_torch_compile
   ```

2. **Try a larger execution horizon**
   ```bash
   # Increase to amortize RTC overhead
   --rtc.execution_horizon=30
   ```

3. **Profile to find the bottleneck**
   ```bash
   uv run examples/rtc/profile_pi0_rtc_detailed.py \
       --policy_path=helper2424/pi05_check_rtc \
       --device=mps \
       --enable_rtc_profiling
   ```

### Preprocessing is slow?

- Reduce the image resolution in the robot config
- Use fewer cameras
- Check camera FPS settings

### Policy inference is slow?

- Enable torch.compile
- Check the device (MPS vs CUDA vs CPU)
- Try a smaller model if available

## 📈 Expected Performance

### Typical timings on Apple Silicon (MPS)

| Component | Time (ms) | Notes |
|-----------|-----------|-------|
| Policy inference | 100-200 | Depends on model size |
| Preprocessing | 5-20 | Depends on number of cameras |
| Postprocessing | 1-5 | Usually fast |
| RTC overhead | 10-50 | Should be < 50% of base |

### When RTC helps

- ✅ Execution horizon ≥ 10
- ✅ Inference time longer than the per-action execution period
- ✅ torch.compile enabled
- ✅ Proper inference_delay calculation

### When RTC might not help

- ❌ Inference is already very fast
- ❌ Small execution horizon (< 5)
- ❌ No compilation (interpreted mode)
- ❌ Inference delay not accounted for
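The first two conditions reduce to a one-line inequality: RTC can hide inference latency only if executing the reusable actions takes at least as long as computing the next chunk. A rough sketch with illustrative numbers:

```python
def rtc_can_hide_latency(inference_ms, execution_horizon, control_hz):
    """True if executing `execution_horizon` actions covers one inference."""
    per_action_ms = 1000.0 / control_hz
    return execution_horizon * per_action_ms >= inference_ms

# 150 ms inference at a 30 Hz control rate:
ok = rtc_can_hide_latency(150.0, execution_horizon=20, control_hz=30.0)       # ~666 ms of reusable actions
too_short = rtc_can_hide_latency(150.0, execution_horizon=4, control_hz=30.0)  # only ~133 ms
```

The thresholds here depend on your own control rate and measured inference time; plug in your profiling numbers.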

## 🛠️ Adding Profiling to Your Code

### Quick snippet

```python
from lerobot.utils.profiling import enable_profiling, print_profiling_summary, profile_section

# Enable at start
enable_profiling()

# Profile sections
with profile_section("my_operation"):
    # ... your code ...
    pass

# Print at the end
print_profiling_summary()
```

### Profile specific methods

```python
from lerobot.utils.profiling import profile_method

@profile_method
def my_slow_function():
    # ... your code ...
    pass
```

## 📝 Example Output

```
PROFILING SUMMARY
================================================================================
Function                                       Count      Mean (ms)
--------------------------------------------------------------------------------
iteration.policy_inference                     20         150.23
iteration.preprocessing                        20         12.45
rtc.denoise_step.guidance_computation          200        15.67
rtc.denoise_step.autograd_correction           200        8.23
rtc.denoise_step.base_denoising                200        120.45
================================================================================
```

## 🚨 Common Issues

### "No profiling data available"

- Did you call `enable_profiling()`?
- Are you running enough iterations?

### Inconsistent results

- Increase `--num_iterations`
- Check for thermal throttling
- Close other applications

### Can't find the bottleneck

- Enable `--enable_rtc_profiling` for a detailed breakdown
- Check both preprocessing and inference
- Compare with and without RTC

## 📖 More Details

See `PROFILING_GUIDE.md` for comprehensive documentation.

## 🤔 Still Slow?

1. Run the comparison: `profile_rtc_comparison.py`
2. Run detailed profiling: `profile_pi0_rtc_detailed.py --enable_rtc_profiling`
3. Share the output for help (include device, model, settings)

## ✅ Quick Checklist

Before asking for help, verify:

- [ ] Ran the comparison script (with and without RTC)
- [ ] Tried torch.compile
- [ ] Tested different execution horizons (10, 20, 30)
- [ ] Profiled with detailed RTC profiling
- [ ] Checked the preprocessing vs inference split
- [ ] Verified hardware (device type, thermal state)
@@ -0,0 +1,352 @@

# RTC Profiling Toolkit

Complete toolkit for profiling Pi0 with RTC to identify performance bottlenecks.

## 📦 What's Included

### Scripts

1. **`eval_with_real_robot_profiled.py`**
   - Profiled version of the real robot eval script
   - Adds timing measurements throughout execution
   - Works with actual robot hardware
   - Same usage as the original, but with profiling output

2. **`profile_rtc_comparison.py`**
   - Side-by-side comparison of RTC vs no-RTC
   - No robot needed (uses mock observations)
   - Shows a clear verdict on whether RTC is helping
   - Great for quick performance checks

3. **`profile_pi0_rtc_detailed.py`**
   - Most detailed profiling available
   - Can enable RTC method-level profiling
   - Provides insights and recommendations
   - Perfect for deep-dive investigations

4. **`add_rtc_profiling.py`**
   - Monkey-patching utility for RTC internals
   - Profiles individual RTC operations
   - Can be applied without modifying source
   - Shows exactly where RTC spends its time

### Utilities

5. **`src/lerobot/utils/profiling.py`**
   - Core profiling utilities
   - Decorators for method profiling
   - Context managers for code blocks
   - Statistics collection and reporting

### Documentation

6. **`PROFILING_GUIDE.md`** - Comprehensive guide
7. **`PROFILING_QUICK_START.md`** - Quick reference

## 🚀 Quick Start

### Step 1: Compare Performance

Run this first to see whether RTC is actually slower:

```bash
uv run examples/rtc/profile_rtc_comparison.py \
    --policy_path=helper2424/pi05_check_rtc \
    --device=mps \
    --num_iterations=50 \
    --execution_horizon=20
```

**Expected output:**

```
COMPARISON SUMMARY
================================================================================
Metric                    Without RTC      With RTC       Difference
--------------------------------------------------------------------------------
Mean time (ms)            150.23           165.45         +15.22
Throughput (iter/s)       6.66             6.05           -0.61
================================================================================
VERDICT
✗ RTC is SLOWER by 10.1%
  Mean time increased by 15.22 ms

Possible reasons:
  - RTC overhead exceeds benefits at current execution horizon
  - No torch.compile enabled
```

### Step 2: Identify Bottleneck

If RTC is slower, find out why:

```bash
uv run examples/rtc/profile_pi0_rtc_detailed.py \
    --policy_path=helper2424/pi05_check_rtc \
    --device=mps \
    --num_iterations=20 \
    --execution_horizon=20 \
    --enable_rtc_profiling
```

**Expected output:**

```
PROFILING SUMMARY
================================================================================
Function                                   Count      Mean (ms)      Total (s)
------------------------------------------------------------------------------------
iteration.policy_inference                 20         150.23         3.00
rtc.denoise_step.guidance_computation      200        15.67          3.13
rtc.denoise_step.autograd_correction       200        8.23           1.65
iteration.preprocessing                    20         12.45          0.25
================================================================================

KEY INSIGHTS
================================================================================
Time breakdown:
  Policy inference:  150.23 ms (87.2%)
  Preprocessing:     12.45 ms (7.2%)
  Postprocessing:    2.10 ms (1.2%)

RTC breakdown:
  Base denoising:    120.45 ms
  Guidance compute:  15.67 ms
  Autograd correct:  8.23 ms
  RTC overhead:      23.90 ms (19.8% of base)

Recommendations:
  ⚠ RTC autograd overhead is significant
    → This is expected, but consider increasing execution_horizon
    → Try torch.compile if not already enabled
  💡 torch.compile not enabled
    → Try --use_torch_compile for potential speedup
================================================================================
```

### Step 3: Try Optimizations

Based on the recommendations:

```bash
# Try with torch.compile
uv run examples/rtc/profile_rtc_comparison.py \
    --policy_path=helper2424/pi05_check_rtc \
    --device=mps \
    --num_iterations=50 \
    --execution_horizon=20 \
    --use_torch_compile

# Try a larger execution horizon
uv run examples/rtc/profile_rtc_comparison.py \
    --policy_path=helper2424/pi05_check_rtc \
    --device=mps \
    --num_iterations=50 \
    --execution_horizon=30
```

### Step 4: Profile Real Robot (Optional)

Test with actual hardware:

```bash
uv run examples/rtc/eval_with_real_robot_profiled.py \
    --policy.path=helper2424/pi05_check_rtc \
    --policy.device=mps \
    --rtc.enabled=true \
    --rtc.execution_horizon=20 \
    --robot.type=so100_follower \
    --robot.port=/dev/tty.usbmodem58FA0834591 \
    --robot.cameras="{...}" \
    --task="Pick up object" \
    --duration=30
```

## 🎯 Common Scenarios

### "RTC is 2x slower!"

This usually means:

- RTC overhead is high without delivering its benefits
- torch.compile needs to be enabled
- The execution horizon is too small
- inference_delay is not calculated correctly

**Try:**

1. `--use_torch_compile`
2. Increase `--execution_horizon` to 30+
3. Check the inference_delay calculation

### "RTC is only slightly slower"

This is expected! RTC overhead is typically about 10-30%.
The benefit comes during **execution**, not single inference:

- Actions are reused across chunks
- Overall system latency is reduced
- The robot gets smoother actions
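To make the execution-time benefit concrete, compare the stall a robot experiences at each chunk boundary. Without RTC the robot idles for the full inference; with RTC it keeps executing leftover actions while the next chunk is computed. A sketch with illustrative numbers:

```python
def stall_per_chunk_ms(inference_ms, covered_actions, control_hz):
    """Time the robot idles at a chunk boundary after exhausting reusable actions."""
    covered_ms = covered_actions * 1000.0 / control_hz
    return max(0.0, inference_ms - covered_ms)

# Illustrative: 165 ms inference, 30 Hz control rate.
no_rtc = stall_per_chunk_ms(165.0, covered_actions=0, control_hz=30.0)      # nothing to reuse
with_rtc = stall_per_chunk_ms(165.0, covered_actions=10, control_hz=30.0)   # 10 leftover actions
```

Even though each RTC inference is slightly slower, the stall at the chunk boundary can drop to zero, which is what the robot actually feels.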

### "Want to optimize a specific part"

Use the profiling utilities:

```python
from lerobot.utils.profiling import enable_profiling, profile_section, print_profiling_summary

enable_profiling()

with profile_section("my_custom_operation"):
    # Your code here
    pass

print_profiling_summary()
```

## 📊 Understanding Results

### Key Metrics

**Policy Inference Time**

- Time for the forward pass through the model
- Should be the largest component (70-90%)
- Includes RTC guidance if enabled

**Preprocessing Time**

- Image normalization and resizing
- Should be < 20% of the total
- If high: reduce image resolution

**RTC Guidance Overhead**

- Extra time for the RTC guidance computation
- Typically 10-30% of base inference
- If > 50%: RTC may not be beneficial at the current settings
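For reference, the quantity this overhead pays for is the guidance weight. As computed in the profiled `denoise_step` shipped with this toolkit (with $\tau$ the inverted flow time and $w_{\max}$ the configured `max_guidance_weight`), it is the clamped schedule:

$$
w(\tau) \;=\; \min\!\left(\frac{1-\tau}{\tau}\cdot\frac{(1-\tau)^2+\tau^2}{(1-\tau)^2},\; w_{\max}\right)
$$

The weight blows up as $\tau \to 0$ and is therefore capped at $w_{\max}$; lowering `max_guidance_weight` trades guidance strength for less correction work.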

**Autograd Correction**

- Time spent computing gradients for RTC
- Usually 5-15% of base inference
- Can be reduced with torch.compile

### Expected Ranges (Apple Silicon MPS)

| Metric | Good | Acceptable | Poor |
|--------|------|------------|------|
| Policy inference | 100-150 ms | 150-250 ms | >250 ms |
| Preprocessing | <20 ms | 20-50 ms | >50 ms |
| RTC overhead | 10-30% | 30-50% | >50% |

## 🔧 Optimization Guide

### If RTC overhead is too high

1. **Enable compilation:**
   ```bash
   --use_torch_compile
   ```
   Expected improvement: 20-40% faster

2. **Increase the execution horizon:**
   ```bash
   --execution_horizon=30  # or higher
   ```
   Amortizes the RTC cost over more actions

3. **Check the guidance weight:**
   ```python
   # In the config
   rtc.max_guidance_weight=1.0  # try 0.5 for less overhead
   ```

### If preprocessing is slow

1. **Reduce image resolution:**
   ```python
   # In the robot config
   cameras={
       "gripper": {"width": 320, "height": 240}  # instead of 640x480
   }
   ```

2. **Use fewer cameras:**
   - Profile which cameras are essential
   - Remove unnecessary views

### If inference is generally slow

1. Use torch.compile (if not already)
2. Check that the device is correct (MPS vs CUDA)
3. Verify the model is in eval mode
4. Check for unnecessary gradient tracking

## 🐛 Troubleshooting

### Empty profiling output

```python
# Make sure to enable profiling!
from lerobot.utils.profiling import enable_profiling

enable_profiling()
```

### Inconsistent timings

- Run more iterations (50-100)
- Check for thermal throttling
- Close background apps
- Use `--warmup_iterations=10`

### Can't find the bottleneck

1. Start with `profile_rtc_comparison.py`
2. Then run `profile_pi0_rtc_detailed.py --enable_rtc_profiling`
3. Compare with and without RTC
4. Check each component separately

## 📖 Full Documentation

- **`PROFILING_GUIDE.md`** - Complete reference with examples
- **`PROFILING_QUICK_START.md`** - Quick commands and tips

## 🤝 Getting Help

If you're still experiencing issues:

1. Run the comparison script and save the output
2. Run detailed profiling and save the output
3. Include:
   - Policy path
   - Device type
   - RTC settings (execution_horizon, etc.)
   - Hardware specs
   - Full profiling output

## 🎓 Learning More

### Profiling your own code

```python
from lerobot.utils.profiling import profile_method, enable_profiling

enable_profiling()

@profile_method
def my_function():
    # Automatically profiled
    pass
```

### RTC internals

```python
from examples.rtc.add_rtc_profiling import monkey_patch_rtc_profiling
from lerobot.utils.profiling import enable_profiling

enable_profiling()
monkey_patch_rtc_profiling()

# Now RTC methods are profiled
policy.predict_action_chunk(...)
```

## ✨ Next Steps

1. Run `profile_rtc_comparison.py` to establish a baseline
2. Use `profile_pi0_rtc_detailed.py` to find bottlenecks
3. Apply optimizations (torch.compile, larger horizon)
4. Re-run the comparison to verify improvements
5. Test with the real robot using the profiled version

Happy profiling! 🚀
@@ -0,0 +1,202 @@

#!/usr/bin/env python

"""
Script to add profiling instrumentation to RTCProcessor.

This script shows which methods to profile in the RTC code to identify bottlenecks.
You can either:
1. Apply these changes directly to modeling_rtc.py
2. Use monkey patching to add profiling without modifying source
3. Use this as a reference for manual instrumentation

Usage:
    # Option 1: Monkey patch (no source changes)
    python examples/rtc/add_rtc_profiling.py

    # Option 2: Apply changes to source
    # Copy the profiled methods below into src/lerobot/policies/rtc/modeling_rtc.py
"""

import logging

import torch
from torch import Tensor

from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.utils.profiling import ProfileContext, enable_profiling, is_profiling_enabled

logger = logging.getLogger(__name__)


def profile_denoise_step(
    self,
    x_t,
    prev_chunk_left_over,
    inference_delay,
    time,
    original_denoise_step_partial,
    execution_horizon=None,
) -> Tensor:
    """Profiled version of denoise_step."""
    if not is_profiling_enabled():
        # Call the original implementation if profiling is disabled
        return self._original_denoise_step(
            x_t, prev_chunk_left_over, inference_delay, time, original_denoise_step_partial, execution_horizon
        )

    with ProfileContext("rtc.denoise_step.total"):
        # In the original implementation the time goes from 0 to 1,
        # while in our implementation it goes from 1 to 0,
        # so we need to invert it.
        tau = 1 - time

        if prev_chunk_left_over is None:
            # First step, no guidance - return v_t
            with ProfileContext("rtc.denoise_step.base_denoising"):
                v_t = original_denoise_step_partial(x_t)
            return v_t

        with ProfileContext("rtc.denoise_step.setup"):
            x_t = x_t.clone().detach()

            squeezed = False
            if len(x_t.shape) < 3:
                x_t = x_t.unsqueeze(0)
                squeezed = True

            if len(prev_chunk_left_over.shape) < 3:
                prev_chunk_left_over = prev_chunk_left_over.unsqueeze(0)

            if execution_horizon is None:
                execution_horizon = self.rtc_config.execution_horizon

            if execution_horizon > prev_chunk_left_over.shape[1]:
                execution_horizon = prev_chunk_left_over.shape[1]

            batch_size = x_t.shape[0]
            action_chunk_size = x_t.shape[1]
            action_dim = x_t.shape[2]

        # Padding
        with ProfileContext("rtc.denoise_step.padding"):
            if prev_chunk_left_over.shape[1] < action_chunk_size or prev_chunk_left_over.shape[2] < action_dim:
                padded = torch.zeros(batch_size, action_chunk_size, action_dim).to(x_t.device)
                padded[:, : prev_chunk_left_over.shape[1], : prev_chunk_left_over.shape[2]] = prev_chunk_left_over
                prev_chunk_left_over = padded

        # Get prefix weights
        with ProfileContext("rtc.denoise_step.get_prefix_weights"):
            weights = (
                self.get_prefix_weights(inference_delay, execution_horizon, action_chunk_size)
                .to(x_t.device)
                .unsqueeze(0)
                .unsqueeze(-1)
            )

        # Main RTC guidance computation
        with ProfileContext("rtc.denoise_step.guidance_computation"):
            with torch.enable_grad():
                # Base denoising
                with ProfileContext("rtc.denoise_step.base_denoising"):
                    v_t = original_denoise_step_partial(x_t)

                x_t.requires_grad_(True)

                # Compute x1_t
                with ProfileContext("rtc.denoise_step.compute_x1_t"):
                    x1_t = x_t - time * v_t

                # Compute the error
                with ProfileContext("rtc.denoise_step.compute_error"):
                    err = (prev_chunk_left_over - x1_t) * weights
                    grad_outputs = err.clone().detach()

                # Compute the correction via autograd
                with ProfileContext("rtc.denoise_step.autograd_correction"):
                    correction = torch.autograd.grad(x1_t, x_t, grad_outputs, retain_graph=False)[0]

        # Compute the guidance weight
        with ProfileContext("rtc.denoise_step.compute_guidance_weight"):
            max_guidance_weight = torch.as_tensor(self.rtc_config.max_guidance_weight)
            tau_tensor = torch.as_tensor(tau)
            squared_one_minus_tau = (1 - tau_tensor) ** 2
            inv_r2 = (squared_one_minus_tau + tau_tensor**2) / squared_one_minus_tau
            c = torch.nan_to_num((1 - tau_tensor) / tau_tensor, posinf=max_guidance_weight)
            guidance_weight = torch.nan_to_num(c * inv_r2, posinf=max_guidance_weight)
            guidance_weight = torch.minimum(guidance_weight, max_guidance_weight)

        # Apply the guidance
        with ProfileContext("rtc.denoise_step.apply_guidance"):
            result = v_t - guidance_weight * correction

        # Cleanup
        with ProfileContext("rtc.denoise_step.cleanup"):
            if squeezed:
                result = result.squeeze(0)
                correction = correction.squeeze(0)
                x1_t = x1_t.squeeze(0)
                err = err.squeeze(0)

        self.track(
            time=time,
            x1_t=x1_t,
            correction=correction,
            err=err,
            weights=weights,
            guidance_weight=guidance_weight,
            inference_delay=inference_delay,
            execution_horizon=execution_horizon,
        )

        return result
def monkey_patch_rtc_profiling():
|
||||
"""Apply profiling to RTCProcessor via monkey patching.
|
||||
|
||||
This modifies the RTCProcessor class at runtime to add profiling
|
||||
without changing source files.
|
||||
"""
|
||||
logger.info("Applying RTC profiling monkey patch...")
|
||||
|
||||
# Save original method
|
||||
RTCProcessor._original_denoise_step = RTCProcessor.denoise_step
|
||||
|
||||
# Replace with profiled version
|
||||
RTCProcessor.denoise_step = profile_denoise_step
|
||||
|
||||
logger.info("✓ RTC profiling enabled")
|
||||
|
||||
|
||||
def print_usage():
|
||||
"""Print usage instructions."""
|
||||
print("\n" + "="*80)
|
||||
print("RTC PROFILING INSTRUMENTATION")
|
||||
print("="*80)
|
||||
print("\nThis script provides profiling for RTCProcessor methods.")
|
||||
print("\nOption 1: Monkey Patch (Recommended)")
|
||||
print("-" * 40)
|
||||
print("Add to your script:")
|
||||
print("""
|
||||
from lerobot.utils.profiling import enable_profiling, print_profiling_summary
|
||||
from examples.rtc.add_rtc_profiling import monkey_patch_rtc_profiling
|
||||
|
||||
# Enable profiling
|
||||
enable_profiling()
|
||||
monkey_patch_rtc_profiling()
|
||||
|
||||
# ... run your code ...
|
||||
|
||||
# Print results
|
||||
print_profiling_summary()
|
||||
""")
|
||||
|
||||
print("\nOption 2: Manual Source Modification")
|
||||
print("-" * 40)
|
||||
print("1. Copy profile_denoise_step() from this file")
|
||||
print("2. Replace denoise_step() in src/lerobot/policies/rtc/modeling_rtc.py")
|
||||
print("3. Add profiling imports at top of file")
|
||||
|
||||
print("\nKey Metrics to Watch:")
|
||||
print("-" * 40)
|
||||
print("- rtc.denoise_step.base_denoising - Time for base policy inference")
|
||||
print("- rtc.denoise_step.autograd_correction - Time computing gradients")
|
||||
print("- rtc.denoise_step.guidance_computation - Total guidance overhead")
|
||||
print("- rtc.denoise_step.get_prefix_weights - Time computing weights")
|
||||
print("="*80 + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print_usage()
|
||||
|
||||
@@ -0,0 +1,631 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Profiled version of eval_with_real_robot.py for performance analysis.

This version adds detailed timing measurements for:
- Policy inference
- Preprocessing
- Postprocessing
- Action queue operations
- Robot communication
- Thread execution times

Usage: Same as eval_with_real_robot.py but with profiling output.
"""

import logging
import math
import sys
import time
import traceback
from collections import defaultdict
from dataclasses import dataclass, field
from threading import Event, Lock, Thread

import torch
from torch import Tensor

from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig  # noqa: F401
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig  # noqa: F401
from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
from lerobot.policies.rtc.action_queue import ActionQueue
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.latency_tracker import LatencyTracker
from lerobot.processor.factory import (
    make_default_robot_action_processor,
    make_default_robot_observation_processor,
)
from lerobot.rl.process import ProcessSignalHandler
from lerobot.robots import (  # noqa: F401
    Robot,
    RobotConfig,
    koch_follower,
    so100_follower,
    so101_follower,
)
from lerobot.robots.utils import make_robot_from_config
from lerobot.utils.constants import OBS_IMAGES
from lerobot.utils.hub import HubMixin
from lerobot.utils.utils import init_logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ProfileTimer:
    """Context manager and utility class for timing code sections."""

    def __init__(self, name: str, stats_dict: dict):
        self.name = name
        self.stats_dict = stats_dict
        self.start_time = None

    def __enter__(self):
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, *args):
        elapsed = time.perf_counter() - self.start_time
        if self.name not in self.stats_dict:
            self.stats_dict[self.name] = []
        self.stats_dict[self.name].append(elapsed)


class ProfilingStats:
    """Global profiling statistics collector."""

    def __init__(self):
        self.stats = defaultdict(list)
        self.lock = Lock()

    def record(self, name: str, duration: float):
        with self.lock:
            self.stats[name].append(duration)

    def timer(self, name: str):
        """Return a context manager for timing."""
        return ProfileTimer(name, self.stats)

    def get_summary(self) -> dict[str, dict[str, float]]:
        """Get summary statistics for all timings."""
        with self.lock:
            summary = {}
            for name, times in self.stats.items():
                if times:
                    summary[name] = {
                        "count": len(times),
                        "mean": sum(times) / len(times),
                        "min": min(times),
                        "max": max(times),
                        "total": sum(times),
                    }
            return summary

    def print_summary(self):
        """Print formatted summary of all timings."""
        summary = self.get_summary()

        logger.info("\n" + "=" * 80)
        logger.info("PROFILING SUMMARY")
        logger.info("=" * 80)

        # Sort by total time (descending)
        sorted_items = sorted(summary.items(), key=lambda x: x[1]["total"], reverse=True)

        for name, stats in sorted_items:
            logger.info(f"\n{name}:")
            logger.info(f"  Count: {stats['count']}")
            logger.info(f"  Mean:  {stats['mean'] * 1000:.2f} ms")
            logger.info(f"  Min:   {stats['min'] * 1000:.2f} ms")
            logger.info(f"  Max:   {stats['max'] * 1000:.2f} ms")
            logger.info(f"  Total: {stats['total']:.2f} s")
            logger.info(f"  Hz:    {stats['count'] / stats['total']:.2f}")

        logger.info("\n" + "=" * 80)


# Global profiling stats
profiling_stats = ProfilingStats()


class RobotWrapper:
    def __init__(self, robot: Robot):
        self.robot = robot
        self.lock = Lock()

    def get_observation(self) -> dict[str, Tensor]:
        with profiling_stats.timer("robot.get_observation"):
            with self.lock:
                return self.robot.get_observation()

    def send_action(self, action: Tensor):
        with profiling_stats.timer("robot.send_action"):
            with self.lock:
                self.robot.send_action(action)

    def observation_features(self) -> list[str]:
        with self.lock:
            return self.robot.observation_features

    def action_features(self) -> list[str]:
        with self.lock:
            return self.robot.action_features


@dataclass
class RTCDemoConfig(HubMixin):
    """Configuration for RTC demo with action chunking policies and real robots."""

    # Policy configuration
    policy: PreTrainedConfig | None = None

    # Robot configuration
    robot: RobotConfig | None = None

    # RTC configuration
    rtc: RTCConfig = field(
        default_factory=lambda: RTCConfig(
            execution_horizon=10,
            max_guidance_weight=1.0,
            prefix_attention_schedule=RTCAttentionSchedule.EXP,
        )
    )

    # Demo parameters
    duration: float = 30.0  # Duration to run the demo (seconds)
    fps: float = 10.0  # Action execution frequency (Hz)

    # Compute device
    device: str | None = None  # Device to run on (cuda, cpu, auto)

    # Queue size at or below which new actions are requested.
    # It should be higher than inference delay + execution horizon.
    action_queue_size_to_get_new_actions: int = 30

    # Task to execute
    task: str = field(default="", metadata={"help": "Task to execute"})

    # Torch compile configuration
    use_torch_compile: bool = field(
        default=False,
        metadata={"help": "Use torch.compile for faster inference (PyTorch 2.0+)"},
    )

    torch_compile_backend: str = field(
        default="inductor",
        metadata={"help": "Backend for torch.compile (inductor, aot_eager, cudagraphs)"},
    )

    torch_compile_mode: str = field(
        default="default",
        metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"},
    )

    torch_compile_disable_cudagraphs: bool = field(
        default=True,
        metadata={
            "help": "Disable CUDA graphs in torch.compile. Required due to in-place tensor "
            "operations in denoising loop (x_t += dt * v_t) which cause tensor aliasing issues."
        },
    )

    def __post_init__(self):
        # HACK: We parse the CLI args again here to get the pretrained path if there was one.
        policy_path = parser.get_path_arg("policy")
        if policy_path:
            cli_overrides = parser.get_cli_overrides("policy")
            self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
            self.policy.pretrained_path = policy_path
        else:
            raise ValueError("Policy path is required")

        # Validate that robot configuration is provided
        if self.robot is None:
            raise ValueError("Robot configuration must be provided")

    @classmethod
    def __get_path_fields__(cls) -> list[str]:
        """This enables the parser to load config from the policy using `--policy.path=local/dir`"""
        return ["policy"]


def is_image_key(k: str) -> bool:
    return k.startswith(OBS_IMAGES)


def get_actions(
    policy,
    robot: RobotWrapper,
    robot_observation_processor,
    action_queue: ActionQueue,
    shutdown_event: Event,
    cfg: RTCDemoConfig,
):
    """Thread function to request action chunks from the policy with profiling.

    Args:
        policy: The policy instance (SmolVLA, Pi0, etc.)
        robot: The robot instance for getting observations
        robot_observation_processor: Processor for raw robot observations
        action_queue: Queue to put new action chunks
        shutdown_event: Event to signal shutdown
        cfg: Demo configuration
    """
    try:
        logger.info("[GET_ACTIONS] Starting get actions thread")

        latency_tracker = LatencyTracker()  # Track latency of action chunks
        fps = cfg.fps
        time_per_chunk = 1.0 / fps

        dataset_features = hw_to_dataset_features(robot.observation_features(), "observation")
        policy_device = policy.config.device

        # Load preprocessor and postprocessor from pretrained files
        logger.info(f"[GET_ACTIONS] Loading preprocessor/postprocessor from {cfg.policy.pretrained_path}")

        preprocessor, postprocessor = make_pre_post_processors(
            policy_cfg=cfg.policy,
            pretrained_path=cfg.policy.pretrained_path,
            dataset_stats=None,  # Will load from pretrained processor files
            preprocessor_overrides={
                "device_processor": {"device": cfg.policy.device},
            },
        )

        logger.info("[GET_ACTIONS] Preprocessor/postprocessor loaded successfully with embedded stats")

        get_actions_threshold = cfg.action_queue_size_to_get_new_actions

        if not cfg.rtc.enabled:
            get_actions_threshold = 0

        inference_count = 0

        while not shutdown_event.is_set():
            if action_queue.qsize() <= get_actions_threshold:
                with profiling_stats.timer("get_actions.total_iteration"):
                    inference_count += 1
                    logger.info(f"[GET_ACTIONS] Starting inference #{inference_count}")

                    current_time = time.perf_counter()
                    action_index_before_inference = action_queue.get_action_index()

                    with profiling_stats.timer("get_actions.get_prev_actions"):
                        prev_actions = action_queue.get_left_over()

                    inference_latency = latency_tracker.max()
                    inference_delay = math.ceil(inference_latency / time_per_chunk)

                    # Get observation
                    obs = robot.get_observation()

                    # Apply robot observation processor
                    with profiling_stats.timer("get_actions.robot_obs_processing"):
                        obs_processed = robot_observation_processor(obs)

                    # Build dataset frame
                    with profiling_stats.timer("get_actions.build_dataset_frame"):
                        obs_with_policy_features = build_dataset_frame(
                            dataset_features, obs_processed, prefix="observation"
                        )

                    # Convert to tensors and normalize
                    with profiling_stats.timer("get_actions.tensor_conversion"):
                        for name in obs_with_policy_features:
                            obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name])
                            if "image" in name:
                                obs_with_policy_features[name] = (
                                    obs_with_policy_features[name].type(torch.float32) / 255
                                )
                                obs_with_policy_features[name] = (
                                    obs_with_policy_features[name].permute(2, 0, 1).contiguous()
                                )
                            obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0)
                            obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device)

                        obs_with_policy_features["task"] = [cfg.task]
                        obs_with_policy_features["robot_type"] = (
                            robot.robot.name if hasattr(robot.robot, "name") else ""
                        )

                    # Preprocessing
                    with profiling_stats.timer("get_actions.preprocessing"):
                        preprocessed_obs = preprocessor(obs_with_policy_features)

                    # Policy inference
                    with profiling_stats.timer("get_actions.policy_inference"):
                        actions = policy.predict_action_chunk(
                            preprocessed_obs,
                            inference_delay=inference_delay,
                            prev_chunk_left_over=prev_actions,
                        )

                    # Clone for RTC
                    with profiling_stats.timer("get_actions.clone_actions"):
                        original_actions = actions.squeeze(0).clone()

                    # Postprocessing
                    with profiling_stats.timer("get_actions.postprocessing"):
                        postprocessed_actions = postprocessor(actions)
                        postprocessed_actions = postprocessed_actions.squeeze(0)

                    # Update latency tracker
                    new_latency = time.perf_counter() - current_time
                    new_delay = math.ceil(new_latency / time_per_chunk)
                    latency_tracker.add(new_latency)

                    logger.info(
                        f"[GET_ACTIONS] Inference #{inference_count} completed in {new_latency * 1000:.2f}ms "
                        f"(delay={new_delay} chunks)"
                    )

                    if cfg.action_queue_size_to_get_new_actions < cfg.rtc.execution_horizon + new_delay:
                        logger.warning(
                            "[GET_ACTIONS] cfg.action_queue_size_to_get_new_actions is too small. "
                            "It should be higher than inference delay + execution horizon."
                        )

                    # Merge into action queue
                    with profiling_stats.timer("get_actions.action_queue_merge"):
                        action_queue.merge(
                            original_actions, postprocessed_actions, new_delay, action_index_before_inference
                        )
            else:
                # Small sleep to prevent busy waiting
                time.sleep(0.1)

        logger.info("[GET_ACTIONS] get actions thread shutting down")
    except Exception as e:
        logger.error(f"[GET_ACTIONS] Fatal exception in get_actions thread: {e}")
        logger.error(traceback.format_exc())
        sys.exit(1)


def actor_control(
    robot: RobotWrapper,
    robot_action_processor,
    action_queue: ActionQueue,
    shutdown_event: Event,
    cfg: RTCDemoConfig,
):
    """Thread function to execute actions on the robot with profiling.

    Args:
        robot: The robot instance
        robot_action_processor: Processor for raw robot actions
        action_queue: Queue to get actions from
        shutdown_event: Event to signal shutdown
        cfg: Demo configuration
    """
    try:
        logger.info("[ACTOR] Starting actor thread")

        action_count = 0
        action_interval = 1.0 / cfg.fps

        while not shutdown_event.is_set():
            start_time = time.perf_counter()

            with profiling_stats.timer("actor.total_iteration"):
                # Get action from queue
                with profiling_stats.timer("actor.queue_get"):
                    action = action_queue.get()

                if action is not None:
                    # Process action
                    with profiling_stats.timer("actor.action_processing"):
                        action = action.cpu()
                        action_dict = {key: action[i].item() for i, key in enumerate(robot.action_features())}
                        action_processed = robot_action_processor((action_dict, None))

                    # Send to robot (includes robot.send_action timing)
                    robot.send_action(action_processed)
                    action_count += 1

            # Sleep to maintain target FPS
            dt_s = time.perf_counter() - start_time
            sleep_time = max(0, (action_interval - dt_s) - 0.001)
            if sleep_time > 0:
                time.sleep(sleep_time)

        logger.info(f"[ACTOR] Actor thread shutting down. Total actions executed: {action_count}")
    except Exception as e:
        logger.error(f"[ACTOR] Fatal exception in actor_control thread: {e}")
        logger.error(traceback.format_exc())
        sys.exit(1)


def _apply_torch_compile(policy, cfg: RTCDemoConfig):
    """Apply torch.compile to the policy's predict_action_chunk method.

    Args:
        policy: Policy instance to compile
        cfg: Configuration containing torch compile settings

    Returns:
        Policy with compiled predict_action_chunk method
    """

    # PI models handle their own compilation
    if policy.type == "pi05" or policy.type == "pi0":
        return policy

    try:
        # Check if torch.compile is available (PyTorch 2.0+)
        if not hasattr(torch, "compile"):
            logger.warning(
                "torch.compile is not available. Requires PyTorch 2.0+. "
                f"Current version: {torch.__version__}. Skipping compilation."
            )
            return policy

        logger.info("Applying torch.compile to predict_action_chunk...")
        logger.info(f"  Backend: {cfg.torch_compile_backend}")
        logger.info(f"  Mode: {cfg.torch_compile_mode}")
        logger.info(f"  Disable CUDA graphs: {cfg.torch_compile_disable_cudagraphs}")

        # Compile the predict_action_chunk method
        compile_kwargs = {
            "backend": cfg.torch_compile_backend,
            "mode": cfg.torch_compile_mode,
        }

        # Disable CUDA graphs if requested (prevents tensor aliasing issues)
        if cfg.torch_compile_disable_cudagraphs:
            compile_kwargs["options"] = {"triton.cudagraphs": False}

        original_method = policy.predict_action_chunk
        compiled_method = torch.compile(original_method, **compile_kwargs)
        policy.predict_action_chunk = compiled_method
        logger.info("✓ Successfully compiled predict_action_chunk")

    except Exception as e:
        logger.error(f"Failed to apply torch.compile: {e}")
        logger.warning("Continuing without torch.compile")

    return policy


@parser.wrap()
def demo_cli(cfg: RTCDemoConfig):
    """Main entry point for RTC demo with profiling."""

    # Initialize logging
    init_logging()

    logger.info(f"Using device: {cfg.device}")
    logger.info("=" * 80)
    logger.info("PROFILING MODE ENABLED")
    logger.info("=" * 80)

    # Setup signal handler for graceful shutdown
    signal_handler = ProcessSignalHandler(use_threads=True, display_pid=False)
    shutdown_event = signal_handler.shutdown_event

    policy = None
    robot = None
    get_actions_thread = None
    actor_thread = None

    policy_class = get_policy_class(cfg.policy.type)

    # Load config and set compile_model for pi0/pi05 models
    config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)

    if cfg.policy.type == "pi05" or cfg.policy.type == "pi0":
        config.compile_model = cfg.use_torch_compile

    policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)

    # Turn on RTC
    policy.config.rtc_config = cfg.rtc

    # Init RTC processor
    policy.init_rtc_processor()

    assert policy.name in ["smolvla", "pi05", "pi0"], "Only smolvla, pi05, and pi0 are supported for RTC"

    policy = policy.to(cfg.device)
    policy.eval()

    # Apply torch.compile to predict_action_chunk method if enabled
    if cfg.use_torch_compile:
        policy = _apply_torch_compile(policy, cfg)

    # Create robot
    logger.info(f"Initializing robot: {cfg.robot.type}")
    robot = make_robot_from_config(cfg.robot)
    robot.connect()
    robot_wrapper = RobotWrapper(robot)

    # Create robot observation and action processors
    robot_observation_processor = make_default_robot_observation_processor()
    robot_action_processor = make_default_robot_action_processor()

    # Create action queue for communication between threads
    action_queue = ActionQueue(cfg.rtc)

    # Start chunk requester thread
    get_actions_thread = Thread(
        target=get_actions,
        args=(policy, robot_wrapper, robot_observation_processor, action_queue, shutdown_event, cfg),
        daemon=True,
        name="GetActions",
    )
    get_actions_thread.start()
    logger.info("Started get actions thread")

    # Start action executor thread
    actor_thread = Thread(
        target=actor_control,
        args=(robot_wrapper, robot_action_processor, action_queue, shutdown_event, cfg),
        daemon=True,
        name="Actor",
    )
    actor_thread.start()
    logger.info("Started actor thread")

    # Main thread monitors for duration or shutdown
    logger.info(f"Running demo for {cfg.duration} seconds...")
    start_time = time.time()

    while not shutdown_event.is_set() and (time.time() - start_time) < cfg.duration:
        time.sleep(1)

        # Log queue status periodically
        if int(time.time() - start_time) % 5 == 0:
            logger.info(f"[MAIN] Action queue size: {action_queue.qsize()}")

    logger.info("Demo duration reached or shutdown requested")

    # Signal shutdown
    shutdown_event.set()

    # Wait for threads to finish
    if get_actions_thread and get_actions_thread.is_alive():
        logger.info("Waiting for chunk requester thread to finish...")
        get_actions_thread.join()

    if actor_thread and actor_thread.is_alive():
        logger.info("Waiting for action executor thread to finish...")
        actor_thread.join()

    # Cleanup robot
    if robot:
        robot.disconnect()
        logger.info("Robot disconnected")

    # Print profiling summary
    profiling_stats.print_summary()

    logger.info("Cleanup completed")


if __name__ == "__main__":
    demo_cli()
    logger.info("RTC demo finished")
@@ -0,0 +1,358 @@
#!/usr/bin/env python

"""
Comprehensive profiling script for Pi0 with RTC.

This script demonstrates how to use all the profiling tools to identify
bottlenecks in Pi0 policy inference with RTC enabled.

It profiles:
1. Overall inference time
2. RTC-specific operations (guidance, weights, etc.)
3. Preprocessing/postprocessing
4. Individual method timings

Usage:
    uv run examples/rtc/profile_pi0_rtc_detailed.py \
        --policy_path=helper2424/pi05_check_rtc \
        --device=mps \
        --num_iterations=20 \
        --execution_horizon=20 \
        --enable_rtc_profiling
"""

import argparse
import logging
import sys
import time

import numpy as np
import torch

from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.utils.profiling import (
    ProfileContext,
    clear_profiling_stats,
    enable_profiling,
    get_profiling_stats,
    print_profiling_summary,
)

# Import monkey patching for RTC profiling
try:
    from examples.rtc.add_rtc_profiling import monkey_patch_rtc_profiling
except ImportError:
    logging.warning("Could not import add_rtc_profiling, detailed RTC profiling disabled")
    monkey_patch_rtc_profiling = None

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def create_mock_observation(policy_config, device: str) -> dict:
    """Create a mock observation matching policy requirements.

    Args:
        policy_config: Policy configuration
        device: Device to create tensors on

    Returns:
        Mock observation dictionary
    """
    obs = {}

    # Create mock state observation
    state_dim = 10  # Typical robot state dimension
    obs["observation.state"] = torch.randn(1, state_dim, device=device)

    # Create mock images if needed
    # For Pi0, we typically need at least one image
    image_height = 224
    image_width = 224

    # Common image keys for Pi0
    image_keys = ["observation.images.gripper", "observation.images.front"]

    for key in image_keys:
        # Images should be [B, C, H, W] and normalized to [0, 1]
        obs[key] = torch.rand(1, 3, image_height, image_width, device=device)

    # Add task
    obs["task"] = ["Pick up the object"]

    # Add language tokens and attention mask (required for Pi0)
    # These are mock values - in real usage they come from tokenizer
    max_seq_len = 32
    obs["observation.language_tokens"] = torch.randint(0, 1000, (1, max_seq_len), device=device)
    obs["observation.language_attention_mask"] = torch.ones(1, max_seq_len, device=device)

    return obs


def profile_single_iteration(
    policy,
    preprocessor,
    postprocessor,
    observation: dict,
    prev_actions: torch.Tensor | None,
    use_rtc: bool,
    inference_delay: int = 0,
) -> tuple[torch.Tensor, torch.Tensor | None, dict]:
    """Profile a single inference iteration.

    Args:
        policy: Policy instance
        preprocessor: Observation preprocessor
        postprocessor: Action postprocessor
        observation: Input observation
        prev_actions: Previous action chunk (for RTC)
        use_rtc: Whether RTC is enabled
        inference_delay: Inference delay in timesteps

    Returns:
        Tuple of (actions, new_prev_actions, timings)
    """
    timings = {}

    with ProfileContext("iteration.total"):
        # Preprocessing
        with ProfileContext("iteration.preprocessing"):
            preprocessed_obs = preprocessor(observation)

        # Policy inference
        with ProfileContext("iteration.policy_inference"):
            if use_rtc:
                actions = policy.predict_action_chunk(
                    preprocessed_obs,
                    inference_delay=inference_delay,
                    prev_chunk_left_over=prev_actions,
                )
            else:
                actions = policy.predict_action_chunk(preprocessed_obs)

        # Clone for next iteration (if RTC)
        new_prev_actions = None
        if use_rtc:
            with ProfileContext("iteration.prepare_prev_actions"):
                execution_horizon = policy.config.rtc_config.execution_horizon
                if actions.shape[1] > execution_horizon:
                    new_prev_actions = actions[:, execution_horizon:].clone()

        # Postprocessing
        with ProfileContext("iteration.postprocessing"):
            processed_actions = postprocessor(actions)

    return processed_actions, new_prev_actions, timings


def main():
    parser = argparse.ArgumentParser(description="Detailed profiling for Pi0 with RTC")
    parser.add_argument("--policy_path", type=str, required=True, help="Path to pretrained policy")
    parser.add_argument("--device", type=str, default="cuda", help="Device (cuda/cpu/mps)")
    parser.add_argument("--num_iterations", type=int, default=20, help="Number of iterations")
    parser.add_argument("--execution_horizon", type=int, default=10, help="RTC execution horizon")
    parser.add_argument("--warmup_iterations", type=int, default=5, help="Warmup iterations")
    parser.add_argument("--enable_rtc_profiling", action="store_true", help="Enable detailed RTC profiling")
    parser.add_argument("--use_torch_compile", action="store_true", help="Use torch.compile")

    args = parser.parse_args()

    logger.info("=" * 80)
    logger.info("DETAILED PI0 RTC PROFILING")
    logger.info("=" * 80)
    logger.info(f"Policy: {args.policy_path}")
    logger.info(f"Device: {args.device}")
    logger.info(f"Iterations: {args.num_iterations}")
    logger.info(f"Execution Horizon: {args.execution_horizon}")
    logger.info(f"RTC Profiling: {args.enable_rtc_profiling}")
    logger.info("=" * 80 + "\n")

    # Enable profiling
    enable_profiling()

    # Apply RTC profiling if requested
    if args.enable_rtc_profiling:
        if monkey_patch_rtc_profiling is not None:
            monkey_patch_rtc_profiling()
            logger.info("✓ Detailed RTC profiling enabled\n")
        else:
            logger.warning("⚠ Could not enable detailed RTC profiling\n")

    # Load policy
    logger.info("Loading policy...")
    config = PreTrainedConfig.from_pretrained(args.policy_path)

    if hasattr(config, "compile_model"):
        config.compile_model = args.use_torch_compile

    policy_class = get_policy_class(config.type)
    policy = policy_class.from_pretrained(args.policy_path, config=config)

    # Configure RTC
    policy.config.rtc_config = RTCConfig(
        enabled=True,
        execution_horizon=args.execution_horizon,
        max_guidance_weight=1.0,
        prefix_attention_schedule=RTCAttentionSchedule.EXP,
    )
    policy.init_rtc_processor()

    policy = policy.to(args.device)
    policy.eval()

    logger.info(f"✓ Policy loaded: {config.type}\n")

    # Create preprocessor and postprocessor
    logger.info("Loading preprocessor/postprocessor...")
    preprocessor, postprocessor = make_pre_post_processors(
        policy_cfg=config,
        pretrained_path=args.policy_path,
        dataset_stats=None,
        preprocessor_overrides={
            "device_processor": {"device": args.device},
        },
    )
    logger.info("✓ Preprocessor/postprocessor loaded\n")
|
||||
|
||||
# Create mock observation
|
||||
logger.info("Creating mock observation...")
|
||||
observation = create_mock_observation(config, args.device)
|
||||
logger.info("✓ Mock observation created\n")
|
||||
|
||||
# Warmup
|
||||
logger.info(f"Warming up ({args.warmup_iterations} iterations)...")
|
||||
prev_actions = None
|
||||
for i in range(args.warmup_iterations):
|
||||
with torch.no_grad():
|
||||
_, prev_actions, _ = profile_single_iteration(
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
observation=observation,
|
||||
prev_actions=prev_actions,
|
||||
use_rtc=True,
|
||||
inference_delay=0,
|
||||
)
|
||||
|
||||
# Clear warmup stats
|
||||
clear_profiling_stats()
|
||||
logger.info("✓ Warmup complete\n")
|
||||
|
||||
# Profiled run WITH RTC
|
||||
logger.info(f"Running profiled iterations WITH RTC ({args.num_iterations} iterations)...")
|
||||
prev_actions = None
|
||||
iteration_times = []
|
||||
|
||||
for i in range(args.num_iterations):
|
||||
start = time.perf_counter()
|
||||
|
||||
with torch.no_grad():
|
||||
_, prev_actions, _ = profile_single_iteration(
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
observation=observation,
|
||||
prev_actions=prev_actions,
|
||||
use_rtc=True,
|
||||
inference_delay=0,
|
||||
)
|
||||
|
||||
# Sync CUDA if needed
|
||||
if args.device.startswith("cuda"):
|
||||
torch.cuda.synchronize()
|
||||
|
||||
elapsed = time.perf_counter() - start
|
||||
iteration_times.append(elapsed)
|
||||
|
||||
if (i + 1) % 5 == 0:
|
||||
logger.info(f" Completed {i+1}/{args.num_iterations}")
|
||||
|
||||
logger.info("✓ Profiling complete\n")
|
||||
|
||||
# Print summary statistics
|
||||
logger.info("\n" + "="*80)
|
||||
logger.info("ITERATION TIMING SUMMARY")
|
||||
logger.info("="*80)
|
||||
|
||||
times_arr = np.array(iteration_times)
|
||||
logger.info(f"Mean time: {np.mean(times_arr)*1000:.2f} ms")
|
||||
logger.info(f"Median time: {np.median(times_arr)*1000:.2f} ms")
|
||||
logger.info(f"Std dev: {np.std(times_arr)*1000:.2f} ms")
|
||||
logger.info(f"Min time: {np.min(times_arr)*1000:.2f} ms")
|
||||
logger.info(f"Max time: {np.max(times_arr)*1000:.2f} ms")
|
||||
logger.info(f"Total time: {np.sum(times_arr):.2f} s")
|
||||
logger.info(f"Throughput: {len(times_arr)/np.sum(times_arr):.2f} iter/s")
|
||||
logger.info("="*80 + "\n")
|
||||
|
||||
# Print detailed profiling breakdown
|
||||
print_profiling_summary(sort_by="total")
|
||||
|
||||
# Print key insights
|
||||
stats = get_profiling_stats()
|
||||
|
||||
logger.info("\n" + "="*80)
|
||||
logger.info("KEY INSIGHTS")
|
||||
logger.info("="*80)
|
||||
|
||||
# Find bottlenecks
|
||||
if stats:
|
||||
policy_inference_time = stats.get("iteration.policy_inference", {}).get("mean", 0)
|
||||
preprocessing_time = stats.get("iteration.preprocessing", {}).get("mean", 0)
|
||||
postprocessing_time = stats.get("iteration.postprocessing", {}).get("mean", 0)
|
||||
|
||||
total_time = policy_inference_time + preprocessing_time + postprocessing_time
|
||||
|
||||
if total_time > 0:
|
||||
logger.info(f"\nTime breakdown:")
|
||||
logger.info(f" Policy inference: {policy_inference_time*1000:.2f} ms ({policy_inference_time/total_time*100:.1f}%)")
|
||||
logger.info(f" Preprocessing: {preprocessing_time*1000:.2f} ms ({preprocessing_time/total_time*100:.1f}%)")
|
||||
logger.info(f" Postprocessing: {postprocessing_time*1000:.2f} ms ({postprocessing_time/total_time*100:.1f}%)")
|
||||
|
||||
# RTC-specific insights
|
||||
if args.enable_rtc_profiling:
|
||||
rtc_guidance = stats.get("rtc.denoise_step.guidance_computation", {}).get("mean", 0)
|
||||
rtc_autograd = stats.get("rtc.denoise_step.autograd_correction", {}).get("mean", 0)
|
||||
rtc_base = stats.get("rtc.denoise_step.base_denoising", {}).get("mean", 0)
|
||||
|
||||
if rtc_guidance > 0:
|
||||
logger.info(f"\nRTC breakdown:")
|
||||
logger.info(f" Base denoising: {rtc_base*1000:.2f} ms")
|
||||
logger.info(f" Guidance compute: {rtc_guidance*1000:.2f} ms")
|
||||
logger.info(f" Autograd correct: {rtc_autograd*1000:.2f} ms")
|
||||
logger.info(f" RTC overhead: {(rtc_guidance - rtc_base)*1000:.2f} ms")
|
||||
|
||||
# Recommendations
|
||||
logger.info("\nRecommendations:")
|
||||
|
||||
if preprocessing_time > policy_inference_time * 0.3:
|
||||
logger.info(" ⚠ Preprocessing is taking >30% of time")
|
||||
logger.info(" → Consider reducing image resolution")
|
||||
logger.info(" → Consider using fewer cameras")
|
||||
|
||||
if args.enable_rtc_profiling and rtc_autograd > rtc_base * 0.5:
|
||||
logger.info(" ⚠ RTC autograd overhead is significant")
|
||||
logger.info(" → This is expected, but consider increasing execution_horizon")
|
||||
logger.info(" → Try torch.compile if not already enabled")
|
||||
|
||||
if not args.use_torch_compile:
|
||||
logger.info(" 💡 torch.compile not enabled")
|
||||
logger.info(" → Try --use_torch_compile for potential speedup")
|
||||
|
||||
logger.info("="*80 + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main()
|
||||
except KeyboardInterrupt:
|
||||
logger.info("\n\nProfiling interrupted by user")
|
||||
sys.exit(0)
|
||||
except Exception as e:
|
||||
logger.error(f"\n\nError during profiling: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
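Both scripts lean on a `ProfileContext` timer and `get_profiling_stats()` from `lerobot.utils.profiling`. A minimal, dependency-free sketch of how such a context-manager timer could work (the real module's internals may differ; the `_STATS` store here is illustrative):

```python
import time
from collections import defaultdict

# Global store: dotted key -> list of elapsed wall-clock times (seconds).
_STATS: dict = defaultdict(list)


class ProfileContext:
    """Context manager that records wall-clock time under a dotted key."""

    def __init__(self, name: str):
        self.name = name

    def __enter__(self):
        self._start = time.perf_counter()
        return self

    def __exit__(self, *exc):
        _STATS[self.name].append(time.perf_counter() - self._start)
        return False  # never swallow exceptions


def get_profiling_stats() -> dict:
    """Summarize recorded timings as mean/total/count per key."""
    return {
        name: {"mean": sum(v) / len(v), "total": sum(v), "count": len(v)}
        for name, v in _STATS.items()
    }


# Usage mirroring the script above:
with ProfileContext("iteration.preprocessing"):
    time.sleep(0.01)  # stand-in for real preprocessing work

stats = get_profiling_stats()
print(sorted(stats))  # ['iteration.preprocessing']
```

Nested `with ProfileContext(...)` blocks, as used in `profile_single_iteration`, simply record under separate keys, so `iteration.total` includes the time of its inner sections.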
@@ -0,0 +1,347 @@
#!/usr/bin/env python

"""
Script to compare performance with and without RTC enabled.

This script helps identify whether RTC is actually improving or degrading performance
by running multiple inference passes and collecting detailed timing statistics.

Usage:
    # Profile with mock data (no robot needed)
    uv run examples/rtc/profile_rtc_comparison.py \
        --policy_path=helper2424/pi05_check_rtc \
        --device=mps \
        --num_iterations=50

    # Profile with a specific RTC config
    uv run examples/rtc/profile_rtc_comparison.py \
        --policy_path=helper2424/pi05_check_rtc \
        --device=mps \
        --num_iterations=50 \
        --execution_horizon=20
"""

import argparse
import logging
import time
from dataclasses import dataclass

import numpy as np
import torch

from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.utils.profiling import (
    clear_profiling_stats,
    enable_profiling,
    get_profiling_stats,
    print_profiling_summary,
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@dataclass
class ProfileResults:
    """Results from a profiling run."""

    mode: str  # "with_rtc" or "without_rtc"
    mean_time: float
    std_time: float
    min_time: float
    max_time: float
    times: list[float]
    throughput: float  # iterations per second


def create_mock_observation(policy, device: str) -> dict:
    """Create a mock observation for testing.

    Args:
        policy: Policy instance
        device: Device to create tensors on

    Returns:
        Mock observation dictionary
    """
    # Get expected input shapes from the policy config.
    # This is a simplified version - adjust based on actual policy requirements.
    obs = {}

    # Mock image and other tensor observations (if the config declares input shapes)
    if hasattr(policy.config, "input_shapes"):
        for key, shape in policy.config.input_shapes.items():
            # Images typically have shape (channels, height, width); prepend a batch dimension.
            obs[key] = torch.randn(1, *shape, device=device)

    # Add a task string if the policy expects one
    if "task" in policy.config.__dict__ or hasattr(policy, "accepts_task"):
        obs["task"] = ["Pick up the object"]

    # Mock state observation
    obs["observation.state"] = torch.randn(1, 10, device=device)  # Adjust size as needed

    return obs


def profile_inference(
    policy, observation: dict, num_iterations: int, use_rtc: bool, execution_horizon: int = 10
) -> ProfileResults:
    """Profile policy inference with or without RTC.

    Args:
        policy: Policy instance
        observation: Observation dictionary
        num_iterations: Number of inference iterations to run
        use_rtc: Whether to enable RTC
        execution_horizon: Execution horizon for RTC

    Returns:
        ProfileResults with timing statistics
    """
    mode = "with_rtc" if use_rtc else "without_rtc"
    logger.info(f"\n{'=' * 80}")
    logger.info(f"Profiling: {mode.upper()}")
    logger.info(f"{'=' * 80}")

    # Configure RTC
    if use_rtc:
        policy.config.rtc_config.enabled = True
        policy.config.rtc_config.execution_horizon = execution_horizon
        policy.init_rtc_processor()
    else:
        policy.config.rtc_config.enabled = False

    times = []
    prev_actions = None

    # Warmup
    logger.info("Warming up (5 iterations)...")
    for _ in range(5):
        with torch.no_grad():
            if use_rtc:
                _ = policy.predict_action_chunk(
                    observation, inference_delay=0, prev_chunk_left_over=prev_actions
                )
            else:
                _ = policy.predict_action_chunk(observation)

    # Actual profiling
    logger.info(f"Running {num_iterations} profiled iterations...")
    for i in range(num_iterations):
        start = time.perf_counter()

        with torch.no_grad():
            if use_rtc:
                actions = policy.predict_action_chunk(
                    observation, inference_delay=0, prev_chunk_left_over=prev_actions
                )
                # Simulate consuming some actions for the next iteration
                if actions.shape[1] > execution_horizon:
                    prev_actions = actions[:, execution_horizon:].clone()
                else:
                    prev_actions = None
            else:
                actions = policy.predict_action_chunk(observation)

        # Synchronize if using CUDA so GPU work is included in the timing
        if observation["observation.state"].device.type == "cuda":
            torch.cuda.synchronize()

        elapsed = time.perf_counter() - start
        times.append(elapsed)

        if (i + 1) % 10 == 0:
            logger.info(f"  Completed {i + 1}/{num_iterations} iterations")

    # Calculate statistics
    times_arr = np.array(times)
    results = ProfileResults(
        mode=mode,
        mean_time=float(np.mean(times_arr)),
        std_time=float(np.std(times_arr)),
        min_time=float(np.min(times_arr)),
        max_time=float(np.max(times_arr)),
        times=times,
        throughput=num_iterations / sum(times),
    )

    logger.info(f"\nResults for {mode}:")
    logger.info(f"  Mean time:  {results.mean_time * 1000:.2f} ms")
    logger.info(f"  Std dev:    {results.std_time * 1000:.2f} ms")
    logger.info(f"  Min time:   {results.min_time * 1000:.2f} ms")
    logger.info(f"  Max time:   {results.max_time * 1000:.2f} ms")
    logger.info(f"  Throughput: {results.throughput:.2f} iter/s")

    return results
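The summary statistics stored in `ProfileResults` can be reproduced standalone with plain numpy. A small sketch of the same computation (the helper name `summarize_times` is illustrative, not part of the script):

```python
import numpy as np


def summarize_times(times: list) -> dict:
    """Compute the same summary statistics that ProfileResults stores."""
    arr = np.array(times)
    return {
        "mean_time": float(np.mean(arr)),
        "std_time": float(np.std(arr)),
        "min_time": float(np.min(arr)),
        "max_time": float(np.max(arr)),
        # Iterations per second over the whole run, not 1 / mean_time.
        "throughput": len(arr) / float(np.sum(arr)),
    }


stats = summarize_times([0.10, 0.12, 0.08, 0.10])
print(f"{stats['mean_time'] * 1000:.2f} ms, {stats['throughput']:.2f} iter/s")
# 100.00 ms, 10.00 iter/s  (mean = 0.40 s / 4; throughput = 4 / 0.40 s)
```

Note that for a constant per-iteration cost, `throughput == 1 / mean_time`; they diverge only through floating-point rounding, since both are derived from the same sum.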


def compare_results(results_without_rtc: ProfileResults, results_with_rtc: ProfileResults):
    """Compare and print results from both runs.

    Args:
        results_without_rtc: Results from the run without RTC
        results_with_rtc: Results from the run with RTC
    """
    logger.info(f"\n{'=' * 80}")
    logger.info("COMPARISON SUMMARY")
    logger.info(f"{'=' * 80}")

    mean_diff = results_with_rtc.mean_time - results_without_rtc.mean_time
    mean_diff_pct = (mean_diff / results_without_rtc.mean_time) * 100

    throughput_diff = results_with_rtc.throughput - results_without_rtc.throughput
    throughput_diff_pct = (throughput_diff / results_without_rtc.throughput) * 100

    logger.info(f"\n{'Metric':<30} {'Without RTC':>15} {'With RTC':>15} {'Difference':>15}")
    logger.info("-" * 80)
    logger.info(
        f"{'Mean time (ms)':<30} "
        f"{results_without_rtc.mean_time * 1000:>15.2f} "
        f"{results_with_rtc.mean_time * 1000:>15.2f} "
        f"{mean_diff * 1000:>+15.2f}"
    )
    logger.info(
        f"{'Std dev (ms)':<30} "
        f"{results_without_rtc.std_time * 1000:>15.2f} "
        f"{results_with_rtc.std_time * 1000:>15.2f} "
        f"{(results_with_rtc.std_time - results_without_rtc.std_time) * 1000:>+15.2f}"
    )
    logger.info(
        f"{'Min time (ms)':<30} "
        f"{results_without_rtc.min_time * 1000:>15.2f} "
        f"{results_with_rtc.min_time * 1000:>15.2f} "
        f"{(results_with_rtc.min_time - results_without_rtc.min_time) * 1000:>+15.2f}"
    )
    logger.info(
        f"{'Max time (ms)':<30} "
        f"{results_without_rtc.max_time * 1000:>15.2f} "
        f"{results_with_rtc.max_time * 1000:>15.2f} "
        f"{(results_with_rtc.max_time - results_without_rtc.max_time) * 1000:>+15.2f}"
    )
    logger.info(
        f"{'Throughput (iter/s)':<30} "
        f"{results_without_rtc.throughput:>15.2f} "
        f"{results_with_rtc.throughput:>15.2f} "
        f"{throughput_diff:>+15.2f} ({throughput_diff_pct:+.1f}%)"
    )

    logger.info(f"\n{'=' * 80}")
    logger.info("VERDICT")
    logger.info(f"{'=' * 80}")

    if mean_diff_pct < -5:
        logger.info(f"✓ RTC is FASTER by {abs(mean_diff_pct):.1f}%")
        logger.info(f"  Mean time reduced by {abs(mean_diff) * 1000:.2f} ms")
    elif mean_diff_pct > 5:
        logger.info(f"✗ RTC is SLOWER by {mean_diff_pct:.1f}%")
        logger.info(f"  Mean time increased by {mean_diff * 1000:.2f} ms")
        logger.info("\n  Possible reasons:")
        logger.info("  - RTC overhead exceeds benefits at current execution horizon")
        logger.info("  - Inference delay calculation not accounting for RTC processing")
        logger.info("  - Additional tensor operations in RTC guidance")
    else:
        logger.info(f"≈ Performance is SIMILAR (difference: {mean_diff_pct:+.1f}%)")

    logger.info(f"{'=' * 80}\n")
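The ±5% verdict thresholds in `compare_results` are easy to isolate and test on their own. A sketch of that classification logic (the helper `rtc_verdict` is hypothetical, not part of the script):

```python
def rtc_verdict(mean_without: float, mean_with: float, threshold_pct: float = 5.0) -> str:
    """Classify an RTC timing difference the same way compare_results does.

    Differences within +/- threshold_pct of the non-RTC baseline count as noise.
    """
    diff_pct = (mean_with - mean_without) / mean_without * 100
    if diff_pct < -threshold_pct:
        return "faster"
    if diff_pct > threshold_pct:
        return "slower"
    return "similar"


print(rtc_verdict(0.100, 0.090))  # faster  (-10%)
print(rtc_verdict(0.100, 0.103))  # similar (+3%, within the noise band)
print(rtc_verdict(0.100, 0.120))  # slower  (+20%)
```

Keeping the threshold relative to the baseline means the same 5 ms difference is noise for a 200 ms policy but a clear regression for a 50 ms one.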


def main():
    parser = argparse.ArgumentParser(description="Profile RTC performance")
    parser.add_argument("--policy_path", type=str, required=True, help="Path to pretrained policy")
    parser.add_argument("--device", type=str, default="cuda", help="Device to run on (cuda/cpu/mps)")
    parser.add_argument("--num_iterations", type=int, default=50, help="Number of inference iterations")
    parser.add_argument("--execution_horizon", type=int, default=10, help="RTC execution horizon")
    parser.add_argument(
        "--enable_detailed_profiling",
        action="store_true",
        help="Enable detailed method-level profiling",
    )
    parser.add_argument(
        "--use_torch_compile", action="store_true", help="Use torch.compile for faster inference"
    )

    args = parser.parse_args()

    # Load policy
    logger.info(f"Loading policy from {args.policy_path}")
    config = PreTrainedConfig.from_pretrained(args.policy_path)
    policy_class = get_policy_class(config.type)

    # Set compile flag if needed
    if hasattr(config, "compile_model"):
        config.compile_model = args.use_torch_compile

    policy = policy_class.from_pretrained(args.policy_path, config=config)

    # Initialize RTC config (profile_inference toggles the enabled flag per run)
    policy.config.rtc_config = RTCConfig(
        execution_horizon=args.execution_horizon,
        max_guidance_weight=1.0,
        prefix_attention_schedule=RTCAttentionSchedule.EXP,
    )

    policy = policy.to(args.device)
    policy.eval()

    logger.info(f"Policy loaded: {config.type}")
    logger.info(f"Device: {args.device}")
    logger.info(f"Execution horizon: {args.execution_horizon}")

    # Create mock observation
    logger.info("Creating mock observation...")
    observation = create_mock_observation(policy, args.device)

    # Enable detailed profiling if requested
    if args.enable_detailed_profiling:
        enable_profiling()
        logger.info("Detailed profiling enabled")

    # Profile without RTC
    results_without_rtc = profile_inference(
        policy=policy,
        observation=observation,
        num_iterations=args.num_iterations,
        use_rtc=False,
        execution_horizon=args.execution_horizon,
    )

    if args.enable_detailed_profiling:
        logger.info("\nDetailed profiling stats (WITHOUT RTC):")
        print_profiling_summary()
        clear_profiling_stats()

    # Profile with RTC
    results_with_rtc = profile_inference(
        policy=policy,
        observation=observation,
        num_iterations=args.num_iterations,
        use_rtc=True,
        execution_horizon=args.execution_horizon,
    )

    if args.enable_detailed_profiling:
        logger.info("\nDetailed profiling stats (WITH RTC):")
        print_profiling_summary()

    # Compare results
    compare_results(results_without_rtc, results_with_rtc)


if __name__ == "__main__":
    main()