Compare commits

...

65 Commits

Author SHA1 Message Date
Michel Aractingi c868777752 profile 2025-11-18 09:51:50 +01:00
Eugene Mironov 8847e75c55 Extract simulator logic from eval_with real robot and add proper headers to files 2025-11-16 19:04:24 +07:00
Eugene Mironov 8429d2ccfa fixup! fixup! Fixup eval with real robot 2025-11-16 18:35:08 +07:00
Eugene Mironov 6794ca2ba8 fixup! Fixup eval with real robot 2025-11-15 00:09:01 +07:00
Eugene Mironov 98c2152f08 Fixup eval with real robot 2025-11-15 00:09:01 +07:00
Eugene Mironov f92999aeb9 fixup! Fix PI0.5 RTC tests to use quantile stats (q01, q99) for normalization 2025-11-15 00:09:01 +07:00
Eugene Mironov 5659c77988 Fix PI0.5 RTC tests to use quantile stats (q01, q99) for normalization
🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-15 00:09:01 +07:00
Eugene Mironov fd88a3acda Fix SmolVLA init_rtc_processor to use getattr instead of direct model access
🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-15 00:09:01 +07:00
Eugene Mironov 6deabe4b71 Fix PI0.5 init_rtc_processor to use getattr instead of direct model access
🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-15 00:09:01 +07:00
Eugene Mironov 2f3525c4a2 Add RTC initialization tests without config for PI0.5 and SmolVLA
Add test_pi05_rtc_initialization_without_rtc_config and
test_smolvla_rtc_initialization_without_rtc_config to verify that
policies can initialize without RTC config and that _rtc_enabled()
returns False in this case.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-15 00:09:01 +07:00
Eugene Mironov d04061def7 fixup! fixup! Fix test to use _rtc_enabled() instead of is_rtc_enabled() 2025-11-15 00:09:01 +07:00
Eugene Mironov 07ee578c78 fixup! Fix test to use _rtc_enabled() instead of is_rtc_enabled() 2025-11-15 00:09:01 +07:00
Eugene Mironov 636e2264c3 Fix test to use _rtc_enabled() instead of is_rtc_enabled()
🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-15 00:09:01 +07:00
Eugene Mironov 5a4c168d92 fixup! Add one more test 2025-11-15 00:09:01 +07:00
Eugene Mironov 047f89cc2a Add one more test 2025-11-15 00:09:01 +07:00
Eugene Mironov 4d64733846 fixup! fixup! Add tests for flow matching models with RTC 2025-11-15 00:09:01 +07:00
Eugene Mironov 0c3ed6ca7a fixup! Add tests for flow matching models with RTC 2025-11-15 00:09:01 +07:00
Eugene Mironov 44322fa726 Add tests for flow matching models with RTC 2025-11-15 00:09:01 +07:00
Eugene Mironov e041634bee Add tests for modeling_rtc 2025-11-15 00:09:01 +07:00
Eugene Mironov 6b6c0623cc Fix tests 2025-11-15 00:09:01 +07:00
Eugene Mironov 6db3afca6f Silent validation 2025-11-15 00:09:01 +07:00
Eugene Mironov 433ccc9603 Update README 2025-11-15 00:09:01 +07:00
Eugene Mironov 9e92337f24 Add validatio at the end 2025-11-15 00:09:01 +07:00
Eugene Mironov 99eea2ae03 Add more tests 2025-11-15 00:09:01 +07:00
Eugene Mironov ac33f20e51 Small fixes 2025-11-15 00:09:01 +07:00
Eugene Mironov ab0a9c3d7a Add workable flow 2025-11-15 00:09:01 +07:00
Eugene Mironov 9616c44024 fixup! fixup! fixup! fixup! fixup! Turn off compilation for pi0/pi05 2025-11-15 00:09:01 +07:00
Eugene Mironov 60b432b0f1 fixup! fixup! fixup! fixup! Turn off compilation for pi0/pi05 2025-11-15 00:09:01 +07:00
Eugene Mironov 513e6c0046 fixup! fixup! fixup! Turn off compilation for pi0/pi05 2025-11-15 00:09:01 +07:00
Eugene Mironov 60362b9c7c fixup! fixup! Turn off compilation for pi0/pi05 2025-11-15 00:09:01 +07:00
Eugene Mironov 5915649eac fixup! Turn off compilation for pi0/pi05 2025-11-15 00:09:01 +07:00
Eugene Mironov 675880392d Turn off compilation for pi0/pi05 2025-11-15 00:09:01 +07:00
Eugene Mironov d0123c4178 fixup! Pi0 eval dataset 2025-11-15 00:09:01 +07:00
Eugene Mironov e86afc883e Pi0 eval dataset 2025-11-15 00:09:01 +07:00
Eugene Mironov d10b7787eb Pi0 2025-11-15 00:09:01 +07:00
Eugene Mironov ac1816ee9c Add RTC to PI0 2025-11-15 00:09:01 +07:00
Eugene Mironov 25fb16ea7a Fix compilation 2025-11-15 00:09:01 +07:00
Eugene Mironov 7baf909e32 Debug 2025-11-15 00:09:01 +07:00
Eugene Mironov 79ffe316e4 Experiemnt with late detach 2025-11-15 00:09:01 +07:00
Eugene Mironov 68b2142bd2 fixup! Add matplotliv to dev 2025-11-15 00:09:01 +07:00
Eugene Mironov a42fb4d0e2 Add matplotliv to dev 2025-11-15 00:09:01 +07:00
Eugene Mironov 83f1de035e delete policies 2025-11-15 00:09:01 +07:00
Eugene Mironov e09a6a90e1 Add torch compilation for eval_dataset 2025-11-15 00:09:01 +07:00
Eugene Mironov 10cc9dd961 Drop not required methods 2025-11-15 00:09:01 +07:00
Eugene Mironov 41b8d4b7c6 Fix tests 2025-11-15 00:09:01 +07:00
Eugene Mironov 7939fc3ddf Add tests for tracker 2025-11-15 00:09:01 +07:00
Eugene Mironov 11b35dfa11 Right kwargs for the policy 2025-11-15 00:09:01 +07:00
Eugene Mironov b27570039c Fix traacking 2025-11-15 00:09:01 +07:00
Eugene Mironov 55c4cc1b27 fixup! fixup! fixup! Improve visualization: separate correction plot and fix axis scaling 2025-11-15 00:09:01 +07:00
Eugene Mironov 3fb3edde3f fixup! fixup! Improve visualization: separate correction plot and fix axis scaling 2025-11-15 00:09:01 +07:00
Eugene Mironov 43bf1fb763 fixup! Improve visualization: separate correction plot and fix axis scaling 2025-11-15 00:09:01 +07:00
Eugene Mironov c7a26f5070 Improve visualization: separate correction plot and fix axis scaling
Changes:
- Create separate figure for correction data instead of overlaying on v_t
- Add _rescale_axes helper method to properly scale all axes
- Add 10% margin to y-axis for better visualization
- Fix v_t chart vertical compression issue

Benefits:
- Clearer v_t plot without correction overlay
- Better axis scaling with proper margins
- Separate correction figure for focused analysis
- Improved readability of all denoising visualizations

Output files:
- denoising_xt_comparison.png (x_t trajectories)
- denoising_vt_comparison.png (v_t velocity - now cleaner)
- denoising_correction_comparison.png (NEW - separate corrections)
- denoising_x1t_comparison.png (x1_t state with error)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
2025-11-15 00:09:01 +07:00
Eugene Mironov aaa308b158 fixup! Refactor plotting loging 2025-11-15 00:09:01 +07:00
Eugene Mironov 84df6cd13d Refactor plotting loging 2025-11-15 00:09:01 +07:00
Eugene Mironov 26db4b64d8 Move plotting logic from modeling_smolvla to eval_dataset script
Refactor to improve separation of concerns:

modeling_smolvla.py changes:
- Remove all plotting logic from sample_actions method
- Remove viz_xt_axs, viz_vt_axs, viz_x1t_axs parameters
- Remove matplotlib and RTCDebugVisualizer imports
- Remove viz_fig, viz_axs, denoise_step_counter instance variables
- Simplify denoising loop to only track data in rtc_processor

eval_dataset.py changes:
- Add _plot_denoising_steps_from_tracker helper method
- Retrieve debug steps from tracker after inference
- Plot x_t, v_t, x1_t, correction, and error from tracker data
- Enable debug tracking (cfg.rtc.debug = True) for visualization
- Remove viz axes parameters from predict_action_chunk calls

modeling_rtc.py changes:
- Remove v_t from track() call (handled by user change)

Benefits:
- Cleaner modeling code focused on inference
- Evaluation script owns all visualization logic
- Better separation of concerns
- Tracker is single source of truth for debug data

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
2025-11-15 00:09:01 +07:00
Eugene Mironov 2204a45020 Refactor SmolVLA plotting to use tracker data instead of local variables
Remove local tracking variables (correction, x1_t, error) from the
denoising loop and instead retrieve plotting data from the RTC tracker
after each denoise step. This makes the code cleaner and uses the
tracker as the single source of truth for debug/visualization data.

Changes:
- Remove initialization of correction, x1_t, error before denoising loop
- After each Euler step, retrieve most recent debug step from tracker
- Extract correction, x1_t, err from debug step for plotting
- Update tracking condition to use is_debug_enabled() method

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
2025-11-15 00:09:01 +07:00
Eugene Mironov b6df884d08 Fix logging buffering and enable tracking when RTC config provided
- Add force=True to logging.basicConfig to override existing configuration
- Enable line buffering for stdout/stderr for real-time log output
- Modify init_rtc_processor to create processor when rtc_config exists
  even if RTC is disabled, allowing tracking of denoising data

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
2025-11-15 00:09:01 +07:00
Eugene Mironov bb23dafad1 fixup! Use output_dir for saving all evaluation images 2025-11-15 00:09:01 +07:00
Eugene Mironov c409ed2d1d Use output_dir for saving all evaluation images
Update eval_dataset.py to save all comparison images to the
configured output_dir instead of the current directory. This provides
better organization and allows users to specify where outputs should be
saved.

Changes:
- Add os import at top level
- Create output_dir at start of run_evaluation()
- Save all comparison images to output_dir
- Remove duplicate os imports
- Update init_rtc_processor() docstring to be more concise

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-15 00:09:01 +07:00
Eugene Mironov d20ef2e46e Rename track_debug method to track
Simplify the method name from track_debug to just track for better
readability and consistency. The method already has clear documentation
about its debug tracking purpose.

Changes:
- Rename RTCProcessor.track_debug() to track()
- Update all call sites in modeling_smolvla.py and modeling_rtc.py

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-15 00:09:01 +07:00
Eugene Mironov 05189361b6 Refactor RTC enabled checks to use _rtc_enabled helper
Add _rtc_enabled() helper method in VLAFlowMatching class to simplify
and clean up RTC enabled checks throughout the code. This reduces
code duplication and improves readability.

Changes:
- Add _rtc_enabled() method in VLAFlowMatching
- Replace verbose rtc_config checks with _rtc_enabled() calls
- Maintain exact same functionality with cleaner code

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-15 00:09:01 +07:00
Eugene Mironov 896779003c Add RTCConfig field to SmolVLAConfig
Add rtc_config as an optional field in SmolVLAConfig to properly
support Real-Time Chunking configuration. This replaces the previous
getattr() workarounds with direct attribute access, making the code
cleaner and more maintainable.

Changes:
- Import RTCConfig in configuration_smolvla.py
- Add rtc_config: RTCConfig | None = None field
- Revert getattr() calls to direct attribute access in modeling_smolvla.py

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-15 00:09:01 +07:00
Eugene Mironov b55bc62ef0 fixup! Fix rtc_config attribute access in SmolVLA 2025-11-15 00:09:01 +07:00
Eugene Mironov 08ff689a1e Fix rtc_config attribute access in SmolVLA
Use getattr() to safely check for rtc_config attribute existence
instead of direct attribute access. This fixes AttributeError when
loading policies without rtc_config in their config.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-15 00:09:01 +07:00
Eugene Mironov 0acdde4ae2 Add Real-Time Chunking (RTC) support for flow matching models
Implement Real-Time Chunking (RTC) for action chunking policies using flow
matching denoising. RTC enables smooth action transitions between consecutive
chunks by using prefix guidance during denoising.

Key features:
- RTCProcessor class with denoise_step method for RTC guidance
- Tracker system for debug tracking using time-based dictionary storage
- RTCDebugVisualizer with comprehensive visualization utilities
- Integration with SmolVLA policy for flow matching models
- Support for multiple prefix attention schedules (ZEROS, ONES, LINEAR, EXP)
- Configurable execution horizon and max guidance weight
- Example scripts for dataset evaluation and real-time control

Technical details:
- Uses autograd-based gradient computation for RTC corrections
- Time-based tracking eliminates duplicate step issues
- Proxy methods in RTCProcessor for cleaner API
- Full integration with LeRobot's policy and dataset systems

Files added/modified:
- src/lerobot/configs/types.py: Add RTCAttentionSchedule enum
- src/lerobot/policies/rtc/: Core RTC implementation
  - configuration_rtc.py: RTC configuration
  - modeling_rtc.py: RTCProcessor with denoise_step
  - debug_handler.py: Tracker for debug information
  - debug_visualizer.py: Visualization utilities
- src/lerobot/policies/smolvla/modeling_smolvla.py: RTC integration
- examples/rtc/: Example scripts and evaluation tools

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-15 00:09:01 +07:00
35 changed files with 9237 additions and 49 deletions
+263
View File
@@ -0,0 +1,263 @@
# RTC Profiling Guide
This guide explains how to profile RTC (Real-Time Chunking) performance to identify bottlenecks and understand why RTC might be slower than expected.
## Quick Start
### 1. Profile with Real Robot (Profiled Version)
Use `eval_with_real_robot_profiled.py` to profile actual robot execution:
```bash
# With RTC enabled
uv run examples/rtc/eval_with_real_robot_profiled.py \
--policy.path=helper2424/pi05_check_rtc \
--policy.device=mps \
--rtc.enabled=true \
--rtc.execution_horizon=20 \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58FA0834591 \
--robot.id=so100_follower \
--robot.cameras="{ gripper: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}}" \
--task="Move green small object into the purple platform" \
--duration=30
# Without RTC for comparison
uv run examples/rtc/eval_with_real_robot_profiled.py \
--policy.path=helper2424/pi05_check_rtc \
--policy.device=mps \
--rtc.enabled=false \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58FA0834591 \
--robot.id=so100_follower \
--robot.cameras="{ gripper: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}}" \
--task="Move green small object into the purple platform" \
--duration=30
```
**Output**: At the end of execution, you'll see a detailed breakdown of timing for each component:
- `get_actions.policy_inference` - Time spent in policy inference
- `get_actions.preprocessing` - Time spent preprocessing observations
- `get_actions.postprocessing` - Time spent postprocessing actions
- `get_actions.action_queue_merge` - Time spent merging actions with RTC
- `robot.get_observation` - Time to get observations from robot
- `robot.send_action` - Time to send actions to robot
- And more...
### 2. Profile Without Robot (Comparison Script)
Use `profile_rtc_comparison.py` to profile just the policy inference without needing a robot:
```bash
uv run examples/rtc/profile_rtc_comparison.py \
--policy_path=helper2424/pi05_check_rtc \
--device=mps \
--num_iterations=50 \
--execution_horizon=20
```
**Output**: Side-by-side comparison of performance with and without RTC, including:
- Mean/min/max inference times
- Throughput (iterations per second)
- Verdict on whether RTC is faster or slower
### 3. Enable Detailed Method-Level Profiling
For even more granular profiling, add the `--enable_detailed_profiling` flag:
```bash
uv run examples/rtc/profile_rtc_comparison.py \
--policy_path=helper2424/pi05_check_rtc \
--device=mps \
--num_iterations=50 \
--execution_horizon=20 \
--enable_detailed_profiling
```
This will show timing for individual methods within the policy.
## Understanding the Output
### Key Metrics to Look At
1. **get_actions.policy_inference** - This should be the largest component
- If RTC is enabled, this includes the RTC guidance overhead
- Compare this with/without RTC to see the overhead
2. **get_actions.preprocessing** - Image preprocessing and normalization
- Should be relatively fast
- If slow, consider optimizing image processing
3. **get_actions.postprocessing** - Action denormalization
- Should be minimal
- If slow, check postprocessor implementation
4. **get_actions.action_queue_merge** - RTC-specific merging logic
- Only present when RTC is enabled
- If this is taking significant time, the RTC algorithm may need optimization
5. **robot.get_observation** - Robot communication overhead
- If slow, check camera/sensor latency
- Consider reducing image resolution
6. **robot.send_action** - Action execution overhead
- Should be very fast
- If slow, check robot communication
### Expected Performance
For a typical Pi0 policy on Apple Silicon (MPS):
- **Without RTC**: ~100-200ms per inference
- **With RTC**: Should be similar or slightly faster due to action reuse
- **Preprocessing**: ~5-20ms depending on number of cameras
- **Postprocessing**: ~1-5ms
If RTC is significantly slower, likely causes:
1. **RTC overhead exceeds benefits** - The guidance computation is expensive
2. **Execution horizon too small** - Not reusing enough actions to amortize overhead
3. **No compilation** - Try with `--use_torch_compile`
4. **Large prev_actions buffer** - Copying/processing previous actions is slow
## Profiling Your Own Code
### Using the Profiling Decorator
Add profiling to your own methods:
```python
from lerobot.utils.profiling import profile_method, enable_profiling, print_profiling_summary
# Enable profiling
enable_profiling()
# Decorate methods you want to profile
@profile_method
def my_slow_function(x):
# ... your code ...
return result
# At end of execution
print_profiling_summary()
```
### Using Profile Context Manager
For profiling specific code blocks:
```python
from lerobot.utils.profiling import profile_section, enable_profiling
enable_profiling()
with profile_section("data_loading"):
data = load_data()
with profile_section("model_inference"):
output = model(data)
```
### Adding Profiling to Policy Methods
To profile specific parts of the Pi0 policy, you can add decorators:
```python
# In src/lerobot/policies/pi0/modeling_pi0.py
from lerobot.utils.profiling import profile_method, profile_section
class Pi0Policy:
@profile_method
def predict_action_chunk(self, obs, inference_delay=0, prev_chunk_left_over=None):
# ... existing code ...
pass
def _generate_actions_with_rtc(self, ...):
with profile_section("rtc.guidance_computation"):
# ... guidance code ...
pass
with profile_section("rtc.action_merging"):
# ... merging code ...
pass
```
## Analyzing Results
### Comparison Checklist
When comparing RTC vs non-RTC performance, check:
- [ ] Is `policy_inference` time higher with RTC?
- [ ] Is `action_queue_merge` taking significant time?
- [ ] Are you running enough iterations to amortize warmup?
- [ ] Is torch.compile enabled for fair comparison?
- [ ] Is the execution horizon large enough? (should be >= 10-20)
- [ ] Are you testing on the same hardware/device?
### Common Bottlenecks
1. **Image preprocessing dominates**
- Solution: Reduce image resolution, use fewer cameras, or optimize preprocessing
2. **Action queue operations are slow**
- Solution: Review queue implementation, consider using ring buffer
3. **RTC guidance is expensive**
- Solution: Reduce guidance weight, simplify guidance computation, use torch.compile
4. **Robot communication is slow**
- Solution: Increase baud rate, reduce action frequency, optimize protocol
5. **Memory allocation overhead**
- Solution: Pre-allocate buffers, reuse tensors, avoid unnecessary copies
## Advanced: Adding Custom Metrics
You can add custom timing metrics to the profiled script:
```python
from lerobot.utils.profiling import record_timing
start = time.perf_counter()
# ... your code ...
duration = time.perf_counter() - start
record_timing("my_custom_metric", duration)
```
## Troubleshooting
### Profiling shows RTC is slower by >50%
1. Check if torch.compile is enabled: `--use_torch_compile`
2. Increase execution horizon: `--rtc.execution_horizon=30`
3. Verify inference_delay is calculated correctly
4. Profile with `--enable_detailed_profiling` to find exact bottleneck
### Profiling output is empty
1. Make sure profiling is enabled with `enable_profiling()`
2. Verify you're running enough iterations (at least 10)
3. Check that code is actually executing (not short-circuited)
### Inconsistent results between runs
1. Run more iterations: `--num_iterations=100`
2. Increase warmup iterations
3. Check for thermal throttling on device
4. Ensure no other processes competing for resources
## Next Steps
1. Run both profiling scripts (with/without robot)
2. Compare timing breakdowns
3. Identify the largest bottleneck
4. Focus optimization efforts on that component
5. Re-run profiling to verify improvements
## Questions?
If profiling reveals unexpected bottlenecks or you need help interpreting results, please share:
- The full profiling output
- Your configuration (RTC enabled/disabled, execution horizon, etc.)
- Hardware specs (device type, memory, etc.)
- Policy type and size
+208
View File
@@ -0,0 +1,208 @@
# RTC Profiling - Quick Start
Quick reference for profiling Pi0 with RTC to identify performance bottlenecks.
## 🚀 Quick Commands
### 1. Profile with Real Robot
```bash
# With RTC enabled (profiled version)
uv run examples/rtc/eval_with_real_robot_profiled.py \
--policy.path=helper2424/pi05_check_rtc \
--policy.device=mps \
--rtc.enabled=true \
--rtc.execution_horizon=20 \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58FA0834591 \
--robot.cameras="{ gripper: {type: opencv, index_or_path: 0}, front: {type: opencv, index_or_path: 1}}" \
--task="Pick up object" \
--duration=30
```
### 2. Compare RTC vs No-RTC (No Robot Needed)
```bash
uv run examples/rtc/profile_rtc_comparison.py \
--policy_path=helper2424/pi05_check_rtc \
--device=mps \
--num_iterations=50 \
--execution_horizon=20
```
### 3. Detailed RTC Method Profiling
```bash
uv run examples/rtc/profile_pi0_rtc_detailed.py \
--policy_path=helper2424/pi05_check_rtc \
--device=mps \
--num_iterations=20 \
--execution_horizon=20 \
--enable_rtc_profiling
```
## 📊 What Each Tool Does
| Tool | Purpose | Needs Robot? |
|------|---------|--------------|
| `eval_with_real_robot_profiled.py` | Profile actual robot execution with RTC | ✅ Yes |
| `profile_rtc_comparison.py` | Compare RTC vs no-RTC side-by-side | ❌ No |
| `profile_pi0_rtc_detailed.py` | Deep dive into RTC internals | ❌ No |
## 🔍 Key Metrics to Watch
### Overall Performance
- **iteration.policy_inference** - Total policy inference time
- **iteration.preprocessing** - Image preprocessing time
- **iteration.postprocessing** - Action denormalization time
### RTC-Specific (with `--enable_rtc_profiling`)
- **rtc.denoise_step.base_denoising** - Time without RTC overhead
- **rtc.denoise_step.autograd_correction** - Gradient computation time
- **rtc.denoise_step.guidance_computation** - Total RTC guidance overhead
### Robot Communication
- **robot.get_observation** - Time to get robot state
- **robot.send_action** - Time to send action command
## 🎯 Quick Diagnosis
### RTC is slower than expected?
1. **Check if torch.compile is enabled**
```bash
# Add this flag
--use_torch_compile
```
2. **Try larger execution horizon**
```bash
# Increase to amortize RTC overhead
--rtc.execution_horizon=30
```
3. **Profile to find bottleneck**
```bash
uv run examples/rtc/profile_pi0_rtc_detailed.py \
--policy_path=helper2424/pi05_check_rtc \
--device=mps \
--enable_rtc_profiling
```
### Preprocessing is slow?
- Reduce image resolution in robot config
- Use fewer cameras
- Check camera FPS settings
### Policy inference is slow?
- Enable torch.compile
- Check device (MPS vs CUDA vs CPU)
- Try smaller model if available
## 📈 Expected Performance
### Typical timings on Apple Silicon (MPS):
| Component | Time (ms) | Notes |
|-----------|-----------|-------|
| Policy inference | 100-200 | Depends on model size |
| Preprocessing | 5-20 | Depends on #cameras |
| Postprocessing | 1-5 | Usually fast |
| RTC overhead | 10-50 | Should be < 50% of base |
### When RTC helps:
- ✅ Execution horizon ≥ 10
- ✅ Inference time > action execution rate
- ✅ Using torch.compile
- ✅ Proper inference_delay calculation
### When RTC might not help:
- ❌ Very fast inference already
- ❌ Small execution horizon (< 5)
- ❌ No compilation (interpreted mode)
- ❌ Inference delay not accounted for
## 🛠️ Adding Profiling to Your Code
### Quick snippet:
```python
from lerobot.utils.profiling import enable_profiling, print_profiling_summary, profile_section
# Enable at start
enable_profiling()
# Profile sections
with profile_section("my_operation"):
# ... your code ...
pass
# Print at end
print_profiling_summary()
```
### Profile specific methods:
```python
from lerobot.utils.profiling import profile_method
@profile_method
def my_slow_function():
# ... your code ...
pass
```
## 📝 Example Output
```
PROFILING SUMMARY
================================================================================
Function Count Mean (ms)
--------------------------------------------------------------------------------
iteration.policy_inference 20 150.23
iteration.preprocessing 20 12.45
rtc.denoise_step.guidance_computation 200 15.67
rtc.denoise_step.autograd_correction 200 8.23
rtc.denoise_step.base_denoising 200 120.45
================================================================================
```
## 🚨 Common Issues
### "No profiling data available"
- Did you call `enable_profiling()`?
- Running enough iterations?
### Inconsistent results
- Increase `--num_iterations`
- Check for thermal throttling
- Close other applications
### Can't find bottleneck
- Enable `--enable_rtc_profiling` for detailed breakdown
- Check both preprocessing and inference
- Compare with and without RTC
## 📖 More Details
See `PROFILING_GUIDE.md` for comprehensive documentation.
## 🤔 Still Slow?
1. Run comparison: `profile_rtc_comparison.py`
2. Run detailed profiling: `profile_pi0_rtc_detailed.py --enable_rtc_profiling`
3. Share output for help (include device, model, settings)
## ✅ Quick Checklist
Before asking for help, verify:
- [ ] Ran comparison script (with/without RTC)
- [ ] Tried torch.compile
- [ ] Tested different execution horizons (10, 20, 30)
- [ ] Profiled with detailed RTC profiling
- [ ] Checked preprocessing vs inference split
- [ ] Verified hardware (device type, thermal state)
+352
View File
@@ -0,0 +1,352 @@
# RTC Profiling Toolkit
Complete toolkit for profiling Pi0 with RTC to identify performance bottlenecks.
## 📦 What's Included
### Scripts
1. **`eval_with_real_robot_profiled.py`**
- Profiled version of the real robot eval script
- Adds timing measurements throughout execution
- Works with actual robot hardware
- Same usage as original but with profiling output
2. **`profile_rtc_comparison.py`**
- Side-by-side comparison of RTC vs no-RTC
- No robot needed (uses mock observations)
- Shows clear verdict on whether RTC is helping
- Great for quick performance checks
3. **`profile_pi0_rtc_detailed.py`**
- Most detailed profiling available
- Can enable RTC method-level profiling
- Provides insights and recommendations
- Perfect for deep-dive investigations
4. **`add_rtc_profiling.py`**
- Monkey-patching utility for RTC internals
- Profiles individual RTC operations
- Can be applied without modifying source
- Shows exactly where RTC spends time
### Utilities
5. **`src/lerobot/utils/profiling.py`**
- Core profiling utilities
- Decorators for method profiling
- Context managers for code blocks
- Statistics collection and reporting
### Documentation
6. **`PROFILING_GUIDE.md`** - Comprehensive guide
7. **`PROFILING_QUICK_START.md`** - Quick reference
## 🚀 Quick Start
### Step 1: Compare Performance
Run this first to see if RTC is actually slower:
```bash
uv run examples/rtc/profile_rtc_comparison.py \
--policy_path=helper2424/pi05_check_rtc \
--device=mps \
--num_iterations=50 \
--execution_horizon=20
```
**Expected output:**
```
COMPARISON SUMMARY
================================================================================
Metric Without RTC With RTC Difference
--------------------------------------------------------------------------------
Mean time (ms) 150.23 165.45 +15.22
Throughput (iter/s) 6.66 6.05 -0.61
================================================================================
VERDICT
✗ RTC is SLOWER by 10.1%
Mean time increased by 15.22 ms
Possible reasons:
- RTC overhead exceeds benefits at current execution horizon
- No torch.compile enabled
```
### Step 2: Identify Bottleneck
If RTC is slower, find out why:
```bash
uv run examples/rtc/profile_pi0_rtc_detailed.py \
--policy_path=helper2424/pi05_check_rtc \
--device=mps \
--num_iterations=20 \
--execution_horizon=20 \
--enable_rtc_profiling
```
**Expected output:**
```
PROFILING SUMMARY
================================================================================
Function Count Mean (ms) Total (s)
------------------------------------------------------------------------------------
iteration.policy_inference 20 150.23 3.00
rtc.denoise_step.guidance_computation 200 15.67 3.13
rtc.denoise_step.autograd_correction 200 8.23 1.65
iteration.preprocessing 20 12.45 0.25
================================================================================
KEY INSIGHTS
================================================================================
Time breakdown:
Policy inference: 150.23 ms (87.2%)
Preprocessing: 12.45 ms (7.2%)
Postprocessing: 2.10 ms (1.2%)
RTC breakdown:
Base denoising: 120.45 ms
Guidance compute: 15.67 ms
Autograd correct: 8.23 ms
RTC overhead: 23.90 ms (19.8% of base)
Recommendations:
⚠ RTC autograd overhead is significant
→ This is expected, but consider increasing execution_horizon
→ Try torch.compile if not already enabled
💡 torch.compile not enabled
→ Try --use_torch_compile for potential speedup
================================================================================
```
### Step 3: Try Optimizations
Based on recommendations:
```bash
# Try with torch.compile
uv run examples/rtc/profile_rtc_comparison.py \
--policy_path=helper2424/pi05_check_rtc \
--device=mps \
--num_iterations=50 \
--execution_horizon=20 \
--use_torch_compile
# Try larger execution horizon
uv run examples/rtc/profile_rtc_comparison.py \
--policy_path=helper2424/pi05_check_rtc \
--device=mps \
--num_iterations=50 \
--execution_horizon=30
```
### Step 4: Profile Real Robot (Optional)
Test with actual hardware:
```bash
uv run examples/rtc/eval_with_real_robot_profiled.py \
--policy.path=helper2424/pi05_check_rtc \
--policy.device=mps \
--rtc.enabled=true \
--rtc.execution_horizon=20 \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58FA0834591 \
--robot.cameras="{...}" \
--task="Pick up object" \
--duration=30
```
## 🎯 Common Scenarios
### "RTC is 2x slower!"
This usually means:
- RTC overhead is high but not getting benefits
- Need to enable torch.compile
- Execution horizon too small
- Inference delay not calculated correctly
**Try:**
1. `--use_torch_compile`
2. Increase `--execution_horizon` to 30+
3. Check inference_delay calculation
### "RTC is only slightly slower"
This is expected! RTC overhead is about 10-30% typically.
The benefit comes during **execution**, not single inference:
- Actions are reused across chunks
- Overall system latency is reduced
- Robot gets smoother actions
### "Want to optimize specific part"
Use the profiling utilities:
```python
from lerobot.utils.profiling import enable_profiling, profile_section, print_profiling_summary
enable_profiling()
with profile_section("my_custom_operation"):
# Your code here
pass
print_profiling_summary()
```
## 📊 Understanding Results
### Key Metrics
**Policy Inference Time**
- Time for forward pass through model
- Should be largest component (70-90%)
- Includes RTC guidance if enabled
**Preprocessing Time**
- Image normalization, resizing
- Should be < 20% of total
- If high: reduce image resolution
**RTC Guidance Overhead**
- Extra time for RTC guidance computation
- Typically 10-30% of base inference
- If > 50%: RTC may not be beneficial at current settings
**Autograd Correction**
- Time computing gradients for RTC
- Usually 5-15% of base inference
- Can be reduced with torch.compile
### Expected Ranges (Apple Silicon MPS)
| Metric | Good | Acceptable | Poor |
|--------|------|------------|------|
| Policy inference | 100-150ms | 150-250ms | >250ms |
| Preprocessing | <20ms | 20-50ms | >50ms |
| RTC overhead | 10-30% | 30-50% | >50% |
## 🔧 Optimization Guide
### If RTC overhead is too high:
1. **Enable compilation:**
```bash
--use_torch_compile
```
Expected improvement: 20-40% faster
2. **Increase execution horizon:**
```bash
--execution_horizon=30 # or higher
```
Amortizes RTC cost over more actions
3. **Check guidance weight:**
```python
# In config
rtc.max_guidance_weight=1.0 # try 0.5 for less overhead
```
### If preprocessing is slow:
1. **Reduce image resolution:**
```python
# In robot config
cameras={
"gripper": {"width": 320, "height": 240} # instead of 640x480
}
```
2. **Use fewer cameras:**
- Profile which cameras are essential
- Remove unnecessary views
### If inference is generally slow:
1. Use torch.compile (if not already)
2. Check device is correct (MPS vs CUDA)
3. Verify model is in eval mode
4. Check for unnecessary gradient tracking
## 🐛 Troubleshooting
### Empty profiling output
```python
# Make sure to enable profiling!
from lerobot.utils.profiling import enable_profiling
enable_profiling()
```
### Inconsistent timings
- Run more iterations (50-100)
- Check thermal throttling
- Close background apps
- Use `--warmup_iterations=10`
### Can't find bottleneck
1. Start with `profile_rtc_comparison.py`
2. Then run `profile_pi0_rtc_detailed.py --enable_rtc_profiling`
3. Compare with/without RTC
4. Check each component separately
## 📖 Full Documentation
- **`PROFILING_GUIDE.md`** - Complete reference with examples
- **`PROFILING_QUICK_START.md`** - Quick commands and tips
## 🤝 Getting Help
If you're still experiencing issues:
1. Run comparison script and save output
2. Run detailed profiling and save output
3. Include:
- Policy path
- Device type
- RTC settings (execution_horizon, etc.)
- Hardware specs
- Full profiling output
## 🎓 Learning More
### Profiling your own code:
```python
from lerobot.utils.profiling import profile_method, enable_profiling
enable_profiling()
@profile_method
def my_function():
# Automatically profiled
pass
```
### RTC internals:
```python
from examples.rtc.add_rtc_profiling import monkey_patch_rtc_profiling
enable_profiling()
monkey_patch_rtc_profiling()
# Now RTC methods are profiled
policy.predict_action_chunk(...)
```
## ✨ Next Steps
1. Run `profile_rtc_comparison.py` to establish baseline
2. Use `profile_pi0_rtc_detailed.py` to find bottlenecks
3. Apply optimizations (torch.compile, larger horizon)
4. Re-run comparison to verify improvements
5. Test with real robot using profiled version
Happy profiling! 🚀
+251
View File
@@ -0,0 +1,251 @@
# Real-Time Chunking (RTC) Examples
This directory contains examples and evaluation scripts for Real-Time Chunking (RTC), a technique for improving action chunking policies in real-time robot control.
## Overview
Real-Time Chunking addresses the challenge of maintaining consistency and reactivity when using action chunking policies with non-negligible inference latency. It uses a guidance technique during diffusion sampling to blend new action predictions with previously planned actions.
**Key Benefits:**
- Maintains consistency between consecutive action chunks
- Reduces jitter and improves smoothness
- Adapts to inference delays dynamically
**Reference:** [Physical Intelligence - Real-Time Chunking](https://www.physicalintelligence.company/download/real_time_chunking.pdf)
## Scripts
### 1. `eval_dataset.py`
Offline evaluation on dataset samples with detailed visualization and validation.
**Features:**
- Compare RTC vs non-RTC predictions on two random dataset samples
- Validate RTC behavior (delay region, blend region, post-horizon region)
- Generate debug visualizations:
- Denoising step comparisons (x_t, v_t, x1_t, corrections)
- Final action predictions comparison
- Support for torch.compile() optimization
- Memory-efficient sequential policy loading for large models
**Usage:**
```bash
# Basic usage with SmolVLA policy
uv run python examples/rtc/eval_dataset.py \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \
--rtc.execution_horizon=8 \
--device=mps \
--rtc.max_guidance_weight=10.0 \
--seed=10
# With Pi0.5 policy on CUDA
uv run python examples/rtc/eval_dataset.py \
--policy.path=lerobot/pi05_libero_finetuned \
--dataset.repo_id=HuggingFaceVLA/libero \
--rtc.execution_horizon=8 \
--device=cuda
# With Pi0 policy
uv run python examples/rtc/eval_dataset.py \
--policy.path=lerobot/pi0_libero_finetuned \
--dataset.repo_id=HuggingFaceVLA/libero \
--rtc.execution_horizon=8 \
--device=cuda
# With torch.compile for faster inference
uv run python examples/rtc/eval_dataset.py \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \
--rtc.execution_horizon=8 \
--device=cuda \
--use_torch_compile=true \
--torch_compile_mode=max-autotune
# Enable CUDA graphs (advanced - may cause tensor aliasing errors)
uv run python examples/rtc/eval_dataset.py \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \
--use_torch_compile=true \
--torch_compile_backend=inductor \
--torch_compile_mode=max-autotune \
--torch_compile_disable_cudagraphs=false
```
**Key Parameters:**
- `--policy.path`: Path to pretrained policy
- `--dataset.repo_id`: Dataset to evaluate on
- `--rtc.execution_horizon`: Number of steps to maintain consistency (default: 20)
- `--rtc.max_guidance_weight`: Maximum guidance weight (default: 10.0)
- `--rtc.prefix_attention_schedule`: Schedule type (ZEROS, ONES, LINEAR, EXP)
- `--inference_delay`: Inference delay for RTC (default: 4)
- `--seed`: Random seed for reproducibility (default: 42)
- `--output_dir`: Directory to save visualizations (default: rtc_debug_output)
- `--device`: Device to use (cuda, cpu, mps, auto)
- `--use_torch_compile`: Enable torch.compile() for faster inference
**Output:**
The script generates several visualization files in `rtc_debug_output/`:
- `denoising_xt_comparison.png` - Noisy state evolution during denoising
- `denoising_vt_comparison.png` - Velocity predictions during denoising
- `denoising_x1t_comparison.png` - Predicted final states during denoising
- `denoising_correction_comparison.png` - RTC guidance corrections applied
- `final_actions_comparison.png` - Final action predictions (prev_chunk, no_rtc, rtc)
The script also validates RTC behavior and reports:
- ✅ Delay region [0:inference_delay]: RTC = prev_chunk
- ✅ Blend region [inference_delay:execution_horizon]: prev_chunk ≤ RTC ≤ no_rtc
- ✅ Post-horizon [execution_horizon:]: RTC = no_rtc
### 2. `eval_with_real_robot.py`
Real-time evaluation on physical robots or simulation environments.
**Features:**
- Run policy with RTC on real robot or simulation
- Multi-threaded action execution and inference
- Action queue management with proper timing
- Latency tracking and adaptive inference delay
- Support for both robots and gym environments
- Support for torch.compile() optimization
**Usage:**
```bash
# With real robot
uv run python examples/rtc/eval_with_real_robot.py \
--policy.path=lerobot/smolvla_base \
--robot.type=so100 \
--task="pick up the cup" \
--duration=30.0
# With simulation environment
uv run python examples/rtc/eval_with_real_robot.py \
--policy.path=lerobot/smolvla_base \
--env.type=pusht \
--duration=60.0
# With policy compilation (CUDA only, not MPS)
uv run python examples/rtc/eval_with_real_robot.py \
--policy.path=lerobot/smolvla_base \
--robot.type=so100 \
--use_torch_compile=true \
--torch_compile_mode=max-autotune
```
**Key Parameters:**
- `--policy.path`: Path to pretrained policy
- `--robot.type` or `--env.type`: Robot or environment to use
- `--task`: Task description (for VLA models)
- `--rtc.execution_horizon`: Number of steps to maintain consistency (default: 10)
- `--rtc.max_guidance_weight`: Maximum guidance weight (default: 1.0)
- `--rtc.prefix_attention_schedule`: Schedule type (ZEROS, ONES, LINEAR, EXP)
- `--duration`: How long to run (seconds, default: 30.0)
- `--fps`: Action execution frequency (Hz, default: 10.0)
- `--action_queue_size_to_get_new_actions`: Queue size threshold to request new actions (default: 30)
- `--device`: Device to use (cuda, cpu, mps, auto)
- `--use_torch_compile`: Enable torch.compile() for faster inference
## Understanding RTC Parameters
### `execution_horizon`
Number of timesteps from previous chunk to maintain consistency with. Higher values mean more consistency but potentially less reactivity.
**Typical values:** 8-12 steps for dataset evaluation, 10 steps for real-time execution
### `max_guidance_weight`
Upper bound on guidance strength. Higher values give stronger consistency but may over-constrain new predictions.
**Typical values:**
- Dataset evaluation: 10.0-100.0 (can be higher for analysis)
- Real-time execution: 1.0-10.0 (more conservative)
### `prefix_attention_schedule`
How to weight consistency across the overlap region:
- `ZEROS`: Binary (full weight up to inference_delay, then zero)
- `ONES`: Full weight across entire execution_horizon
- `LINEAR`: Linear decay from inference_delay to execution_horizon
- `EXP`: Exponential decay (recommended)
**Recommended:** `EXP`
### `inference_delay`
Number of timesteps from the prefix to use for guidance. Typically calculated dynamically based on inference latency in real-time execution, but fixed for dataset evaluation.
**Typical values:** 3-5 steps for dataset evaluation
### `action_queue_size_to_get_new_actions` (real-time only)
Threshold for requesting new action chunks. Should be higher than `inference_delay + execution_horizon` to ensure smooth operation.
**Typical values:** 20-30 steps
## Validation Rules (Dataset Evaluation)
The dataset evaluation script validates that RTC behavior matches expectations:
1. **Delay Region [0:inference_delay]**: RTC actions should equal previous chunk
- Ensures consistency during the inference delay period
2. **Blend Region [inference_delay:execution_horizon]**: RTC should be between prev_chunk and no_rtc
- Smooth transition from previous plan to new predictions
3. **Post-Horizon [execution_horizon:]**: RTC should equal no_rtc
- Full adoption of new predictions after execution horizon
## Tips
1. **Start with dataset evaluation** (`eval_dataset.py`) to understand RTC behavior and tune parameters before running on robot
2. **Use visualizations** to debug unexpected behavior - check denoising steps and final actions
3. **Tune execution_horizon** based on your inference latency and action frequency
4. **Monitor validation output** - failures indicate potential implementation issues or misconfigured parameters
5. **Compare different schedules** - EXP usually works best but LINEAR can be more interpretable
## Troubleshooting
### Validation fails in delay region
- Check that `prev_chunk_left_over` is properly passed to the policy
- Verify RTC guidance is being applied during denoising
- Look at denoising visualizations to see where guidance diverges
### Validation fails in post-horizon region
- RTC and no_rtc use different noise - verify same noise is being used for comparison
- Check that weights are correctly zeroed out after execution horizon
- Review prefix_attention_schedule visualization
### Poor performance on real robot
- Increase `action_queue_size_to_get_new_actions` if you see warnings
- Reduce `max_guidance_weight` if robot is too conservative
- Try different `prefix_attention_schedule` values
- Enable torch.compile() for faster inference (CUDA only)
### Memory issues with large models
- The dataset evaluation script loads policies sequentially to minimize memory
- For real-time execution, only one policy is loaded
- Use smaller batch sizes if needed
## Related Documentation
- [RTC Implementation](../../src/lerobot/policies/rtc/modeling_rtc.py)
- [RTC Configuration](../../src/lerobot/policies/rtc/configuration_rtc.py)
- [Action Queue](../../src/lerobot/policies/rtc/action_queue.py)
- [Physical Intelligence Paper](https://www.physicalintelligence.company/download/real_time_chunking.pdf)
+202
View File
@@ -0,0 +1,202 @@
#!/usr/bin/env python
"""
Script to add profiling instrumentation to RTCProcessor.
This script shows which methods to profile in the RTC code to identify bottlenecks.
You can either:
1. Apply these changes directly to modeling_rtc.py
2. Use monkey patching to add profiling without modifying source
3. Use as reference for manual instrumentation
Usage:
# Option 1: Monkey patch (no source changes)
python examples/rtc/add_rtc_profiling.py
# Option 2: Apply changes to source
# Copy the profiled methods below into src/lerobot/policies/rtc/modeling_rtc.py
"""
import logging
import torch
from torch import Tensor
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.utils.profiling import ProfileContext, enable_profiling, is_profiling_enabled
logger = logging.getLogger(__name__)
def profile_denoise_step(self, x_t, prev_chunk_left_over, inference_delay, time, original_denoise_step_partial, execution_horizon=None) -> Tensor:
"""Profiled version of denoise_step."""
if not is_profiling_enabled():
# Call original implementation if profiling disabled
return self._original_denoise_step(x_t, prev_chunk_left_over, inference_delay, time, original_denoise_step_partial, execution_horizon)
with ProfileContext("rtc.denoise_step.total"):
# In the original implementation, the time goes from 0 to 1 and
# In our implementation, the time goes from 1 to 0
# So we need to invert the time
tau = 1 - time
if prev_chunk_left_over is None:
# First step, no guidance - return v_t
with ProfileContext("rtc.denoise_step.base_denoising"):
v_t = original_denoise_step_partial(x_t)
return v_t
with ProfileContext("rtc.denoise_step.setup"):
x_t = x_t.clone().detach()
squeezed = False
if len(x_t.shape) < 3:
x_t = x_t.unsqueeze(0)
squeezed = True
if len(prev_chunk_left_over.shape) < 3:
prev_chunk_left_over = prev_chunk_left_over.unsqueeze(0)
if execution_horizon is None:
execution_horizon = self.rtc_config.execution_horizon
if execution_horizon > prev_chunk_left_over.shape[1]:
execution_horizon = prev_chunk_left_over.shape[1]
batch_size = x_t.shape[0]
action_chunk_size = x_t.shape[1]
action_dim = x_t.shape[2]
# Padding
with ProfileContext("rtc.denoise_step.padding"):
if prev_chunk_left_over.shape[1] < action_chunk_size or prev_chunk_left_over.shape[2] < action_dim:
padded = torch.zeros(batch_size, action_chunk_size, action_dim).to(x_t.device)
padded[:, : prev_chunk_left_over.shape[1], : prev_chunk_left_over.shape[2]] = prev_chunk_left_over
prev_chunk_left_over = padded
# Get prefix weights
with ProfileContext("rtc.denoise_step.get_prefix_weights"):
weights = (
self.get_prefix_weights(inference_delay, execution_horizon, action_chunk_size)
.to(x_t.device)
.unsqueeze(0)
.unsqueeze(-1)
)
# Main RTC guidance computation
with ProfileContext("rtc.denoise_step.guidance_computation"):
with torch.enable_grad():
# Base denoising
with ProfileContext("rtc.denoise_step.base_denoising"):
v_t = original_denoise_step_partial(x_t)
x_t.requires_grad_(True)
# Compute x1_t
with ProfileContext("rtc.denoise_step.compute_x1_t"):
x1_t = x_t - time * v_t
# Compute error
with ProfileContext("rtc.denoise_step.compute_error"):
err = (prev_chunk_left_over - x1_t) * weights
grad_outputs = err.clone().detach()
# Compute correction via autograd
with ProfileContext("rtc.denoise_step.autograd_correction"):
correction = torch.autograd.grad(x1_t, x_t, grad_outputs, retain_graph=False)[0]
# Compute guidance weight
with ProfileContext("rtc.denoise_step.compute_guidance_weight"):
max_guidance_weight = torch.as_tensor(self.rtc_config.max_guidance_weight)
tau_tensor = torch.as_tensor(tau)
squared_one_minus_tau = (1 - tau_tensor) ** 2
inv_r2 = (squared_one_minus_tau + tau_tensor**2) / (squared_one_minus_tau)
c = torch.nan_to_num((1 - tau_tensor) / tau_tensor, posinf=max_guidance_weight)
guidance_weight = torch.nan_to_num(c * inv_r2, posinf=max_guidance_weight)
guidance_weight = torch.minimum(guidance_weight, max_guidance_weight)
# Apply guidance
with ProfileContext("rtc.denoise_step.apply_guidance"):
result = v_t - guidance_weight * correction
# Cleanup
with ProfileContext("rtc.denoise_step.cleanup"):
if squeezed:
result = result.squeeze(0)
correction = correction.squeeze(0)
x1_t = x1_t.squeeze(0)
err = err.squeeze(0)
self.track(
time=time,
x1_t=x1_t,
correction=correction,
err=err,
weights=weights,
guidance_weight=guidance_weight,
inference_delay=inference_delay,
execution_horizon=execution_horizon,
)
return result
def monkey_patch_rtc_profiling():
"""Apply profiling to RTCProcessor via monkey patching.
This modifies the RTCProcessor class at runtime to add profiling
without changing source files.
"""
logger.info("Applying RTC profiling monkey patch...")
# Save original method
RTCProcessor._original_denoise_step = RTCProcessor.denoise_step
# Replace with profiled version
RTCProcessor.denoise_step = profile_denoise_step
logger.info("✓ RTC profiling enabled")
def print_usage():
"""Print usage instructions."""
print("\n" + "="*80)
print("RTC PROFILING INSTRUMENTATION")
print("="*80)
print("\nThis script provides profiling for RTCProcessor methods.")
print("\nOption 1: Monkey Patch (Recommended)")
print("-" * 40)
print("Add to your script:")
print("""
from lerobot.utils.profiling import enable_profiling, print_profiling_summary
from examples.rtc.add_rtc_profiling import monkey_patch_rtc_profiling
# Enable profiling
enable_profiling()
monkey_patch_rtc_profiling()
# ... run your code ...
# Print results
print_profiling_summary()
""")
print("\nOption 2: Manual Source Modification")
print("-" * 40)
print("1. Copy profile_denoise_step() from this file")
print("2. Replace denoise_step() in src/lerobot/policies/rtc/modeling_rtc.py")
print("3. Add profiling imports at top of file")
print("\nKey Metrics to Watch:")
print("-" * 40)
print("- rtc.denoise_step.base_denoising - Time for base policy inference")
print("- rtc.denoise_step.autograd_correction - Time computing gradients")
print("- rtc.denoise_step.guidance_computation - Total guidance overhead")
print("- rtc.denoise_step.get_prefix_weights - Time computing weights")
print("="*80 + "\n")
if __name__ == "__main__":
print_usage()
File diff suppressed because it is too large Load Diff
+549
View File
@@ -0,0 +1,549 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Demo script showing how to use Real-Time Chunking (RTC) with action chunking policies on real robots.
This script demonstrates:
1. Creating a robot and policy (SmolVLA, Pi0, etc.) with RTC
2. Consuming actions from the policy while the robot executes
3. Periodically requesting new action chunks in the background using threads
4. Managing action buffers and timing for real-time operation
For simulation environments, see eval_with_simulation.py
Usage:
# Run RTC with Real robot with RTC
uv run examples/rtc/eval_with_real_robot.py \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--policy.device=mps \
--rtc.enabled=true \
--rtc.execution_horizon=20 \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58FA0834591 \
--robot.id=so100_follower \
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--task="Move green small object into the purple platform" \
--duration=120
# Run RTC with Real robot without RTC
uv run examples/rtc/eval_with_real_robot.py \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--policy.device=mps \
--rtc.enabled=false \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58FA0834591 \
--robot.id=so100_follower \
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--task="Move green small object into the purple platform" \
--duration=120
# Run RTC with Real robot with pi0.5 policy
uv run examples/rtc/eval_with_real_robot.py \
--policy.path=helper2424/pi05_check_rtc \
--policy.device=mps \
--rtc.enabled=true \
--rtc.execution_horizon=20 \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58FA0834591 \
--robot.id=so100_follower \
--robot.cameras="{ gripper: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}}" \
--task="Move green small object into the purple platform" \
--duration=120
"""
import logging
import math
import sys
import time
import traceback
from dataclasses import dataclass, field
from threading import Event, Lock, Thread
import torch
from torch import Tensor
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
from lerobot.policies.rtc.action_queue import ActionQueue
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.latency_tracker import LatencyTracker
from lerobot.processor.factory import (
make_default_robot_action_processor,
make_default_robot_observation_processor,
)
from lerobot.rl.process import ProcessSignalHandler
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
koch_follower,
so100_follower,
so101_follower,
)
from lerobot.robots.utils import make_robot_from_config
from lerobot.utils.constants import OBS_IMAGES
from lerobot.utils.hub import HubMixin
from lerobot.utils.utils import init_logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class RobotWrapper:
def __init__(self, robot: Robot):
self.robot = robot
self.lock = Lock()
def get_observation(self) -> dict[str, Tensor]:
with self.lock:
return self.robot.get_observation()
def send_action(self, action: Tensor):
with self.lock:
self.robot.send_action(action)
def observation_features(self) -> list[str]:
with self.lock:
return self.robot.observation_features
def action_features(self) -> list[str]:
with self.lock:
return self.robot.action_features
@dataclass
class RTCDemoConfig(HubMixin):
"""Configuration for RTC demo with action chunking policies and real robots."""
# Policy configuration
policy: PreTrainedConfig | None = None
# Robot configuration
robot: RobotConfig | None = None
# RTC configuration
rtc: RTCConfig = field(
default_factory=lambda: RTCConfig(
execution_horizon=10,
max_guidance_weight=1.0,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
)
)
# Demo parameters
duration: float = 30.0 # Duration to run the demo (seconds)
fps: float = 10.0 # Action execution frequency (Hz)
# Compute device
device: str | None = None # Device to run on (cuda, cpu, auto)
# Get new actions horizon. The amount of executed steps after which will be requested new actions.
# It should be higher than inference delay + execution horizon.
action_queue_size_to_get_new_actions: int = 30
# Task to execute
task: str = field(default="", metadata={"help": "Task to execute"})
# Torch compile configuration
use_torch_compile: bool = field(
default=False,
metadata={"help": "Use torch.compile for faster inference (PyTorch 2.0+)"},
)
torch_compile_backend: str = field(
default="inductor",
metadata={"help": "Backend for torch.compile (inductor, aot_eager, cudagraphs)"},
)
torch_compile_mode: str = field(
default="default",
metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"},
)
torch_compile_disable_cudagraphs: bool = field(
default=True,
metadata={
"help": "Disable CUDA graphs in torch.compile. Required due to in-place tensor "
"operations in denoising loop (x_t += dt * v_t) which cause tensor aliasing issues."
},
)
def __post_init__(self):
# HACK: We parse again the cli args here to get the pretrained path if there was one.
policy_path = parser.get_path_arg("policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
else:
raise ValueError("Policy path is required")
# Validate that robot configuration is provided
if self.robot is None:
raise ValueError("Robot configuration must be provided")
@classmethod
def __get_path_fields__(cls) -> list[str]:
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
return ["policy"]
def is_image_key(k: str) -> bool:
return k.startswith(OBS_IMAGES)
def get_actions(
policy,
robot: RobotWrapper,
robot_observation_processor,
action_queue: ActionQueue,
shutdown_event: Event,
cfg: RTCDemoConfig,
):
"""Thread function to request action chunks from the policy.
Args:
policy: The policy instance (SmolVLA, Pi0, etc.)
robot: The robot instance for getting observations
robot_observation_processor: Processor for raw robot observations
action_queue: Queue to put new action chunks
shutdown_event: Event to signal shutdown
cfg: Demo configuration
"""
try:
logger.info("[GET_ACTIONS] Starting get actions thread")
latency_tracker = LatencyTracker() # Track latency of action chunks
fps = cfg.fps
time_per_chunk = 1.0 / fps
dataset_features = hw_to_dataset_features(robot.observation_features(), "observation")
policy_device = policy.config.device
# Load preprocessor and postprocessor from pretrained files
# The stats are embedded in the processor .safetensors files
logger.info(f"[GET_ACTIONS] Loading preprocessor/postprocessor from {cfg.policy.pretrained_path}")
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
dataset_stats=None, # Will load from pretrained processor files
preprocessor_overrides={
"device_processor": {"device": cfg.policy.device},
},
)
logger.info("[GET_ACTIONS] Preprocessor/postprocessor loaded successfully with embedded stats")
get_actions_threshold = cfg.action_queue_size_to_get_new_actions
if not cfg.rtc.enabled:
get_actions_threshold = 0
while not shutdown_event.is_set():
if action_queue.qsize() <= get_actions_threshold:
current_time = time.perf_counter()
action_index_before_inference = action_queue.get_action_index()
prev_actions = action_queue.get_left_over()
inference_latency = latency_tracker.max()
inference_delay = math.ceil(inference_latency / time_per_chunk)
obs = robot.get_observation()
# Apply robot observation processor
obs_processed = robot_observation_processor(obs)
obs_with_policy_features = build_dataset_frame(
dataset_features, obs_processed, prefix="observation"
)
for name in obs_with_policy_features:
obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name])
if "image" in name:
obs_with_policy_features[name] = (
obs_with_policy_features[name].type(torch.float32) / 255
)
obs_with_policy_features[name] = (
obs_with_policy_features[name].permute(2, 0, 1).contiguous()
)
obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0)
obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device)
obs_with_policy_features["task"] = [cfg.task] # Task should be a list, not a string!
obs_with_policy_features["robot_type"] = (
robot.robot.name if hasattr(robot.robot, "name") else ""
)
preproceseded_obs = preprocessor(obs_with_policy_features)
# Generate actions WITH RTC
actions = policy.predict_action_chunk(
preproceseded_obs,
inference_delay=inference_delay,
prev_chunk_left_over=prev_actions,
)
# Store original actions (before postprocessing) for RTC
original_actions = actions.squeeze(0).clone()
postprocessed_actions = postprocessor(actions)
postprocessed_actions = postprocessed_actions.squeeze(0)
new_latency = time.perf_counter() - current_time
new_delay = math.ceil(new_latency / time_per_chunk)
latency_tracker.add(new_latency)
if cfg.action_queue_size_to_get_new_actions < cfg.rtc.execution_horizon + new_delay:
logger.warning(
"[GET_ACTIONS] cfg.action_queue_size_to_get_new_actions Too small, It should be higher than inference delay + execution horizon."
)
action_queue.merge(
original_actions, postprocessed_actions, new_delay, action_index_before_inference
)
else:
# Small sleep to prevent busy waiting
time.sleep(0.1)
logger.info("[GET_ACTIONS] get actions thread shutting down")
except Exception as e:
logger.error(f"[GET_ACTIONS] Fatal exception in get_actions thread: {e}")
logger.error(traceback.format_exc())
sys.exit(1)
def actor_control(
robot: RobotWrapper,
robot_action_processor,
action_queue: ActionQueue,
shutdown_event: Event,
cfg: RTCDemoConfig,
):
"""Thread function to execute actions on the robot.
Args:
robot: The robot instance
action_queue: Queue to get actions from
shutdown_event: Event to signal shutdown
cfg: Demo configuration
"""
try:
logger.info("[ACTOR] Starting actor thread")
action_count = 0
action_interval = 1.0 / cfg.fps
while not shutdown_event.is_set():
start_time = time.perf_counter()
# Try to get an action from the queue with timeout
action = action_queue.get()
if action is not None:
action = action.cpu()
action_dict = {key: action[i].item() for i, key in enumerate(robot.action_features())}
action_processed = robot_action_processor((action_dict, None))
robot.send_action(action_processed)
action_count += 1
dt_s = time.perf_counter() - start_time
time.sleep(max(0, (action_interval - dt_s) - 0.001))
logger.info(f"[ACTOR] Actor thread shutting down. Total actions executed: {action_count}")
except Exception as e:
logger.error(f"[ACTOR] Fatal exception in actor_control thread: {e}")
logger.error(traceback.format_exc())
sys.exit(1)
def _apply_torch_compile(policy, cfg: RTCDemoConfig):
"""Apply torch.compile to the policy's predict_action_chunk method.
Args:
policy: Policy instance to compile
cfg: Configuration containing torch compile settings
Returns:
Policy with compiled predict_action_chunk method
"""
# PI models handle their own compilation
if policy.type == "pi05" or policy.type == "pi0":
return policy
try:
# Check if torch.compile is available (PyTorch 2.0+)
if not hasattr(torch, "compile"):
logger.warning(
f"torch.compile is not available. Requires PyTorch 2.0+. "
f"Current version: {torch.__version__}. Skipping compilation."
)
return policy
logger.info("Applying torch.compile to predict_action_chunk...")
logger.info(f" Backend: {cfg.torch_compile_backend}")
logger.info(f" Mode: {cfg.torch_compile_mode}")
logger.info(f" Disable CUDA graphs: {cfg.torch_compile_disable_cudagraphs}")
# Compile the predict_action_chunk method
# - CUDA graphs disabled to prevent tensor aliasing from in-place ops (x_t += dt * v_t)
compile_kwargs = {
"backend": cfg.torch_compile_backend,
"mode": cfg.torch_compile_mode,
}
# Disable CUDA graphs if requested (prevents tensor aliasing issues)
if cfg.torch_compile_disable_cudagraphs:
compile_kwargs["options"] = {"triton.cudagraphs": False}
original_method = policy.predict_action_chunk
compiled_method = torch.compile(original_method, **compile_kwargs)
policy.predict_action_chunk = compiled_method
logger.info("✓ Successfully compiled predict_action_chunk")
except Exception as e:
logger.error(f"Failed to apply torch.compile: {e}")
logger.warning("Continuing without torch.compile")
return policy
@parser.wrap()
def demo_cli(cfg: RTCDemoConfig):
"""Main entry point for RTC demo with draccus configuration."""
# Initialize logging
init_logging()
logger.info(f"Using device: {cfg.device}")
# Setup signal handler for graceful shutdown
signal_handler = ProcessSignalHandler(use_threads=True, display_pid=False)
shutdown_event = signal_handler.shutdown_event
policy = None
robot = None
get_actions_thread = None
actor_thread = None
policy_class = get_policy_class(cfg.policy.type)
# Load config and set compile_model for pi0/pi05 models
config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
if cfg.policy.type == "pi05" or cfg.policy.type == "pi0":
config.compile_model = cfg.use_torch_compile
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)
# Turn on RTC
policy.config.rtc_config = cfg.rtc
# Init RTC processort, as by default if RTC disabled in the config
# The processor won't be created
policy.init_rtc_processor()
assert policy.name in ["smolvla", "pi05", "pi0"], "Only smolvla, pi05, and pi0 are supported for RTC"
policy = policy.to(cfg.device)
policy.eval()
# Apply torch.compile to predict_action_chunk method if enabled
if cfg.use_torch_compile:
policy = _apply_torch_compile(policy, cfg)
# Create robot
logger.info(f"Initializing robot: {cfg.robot.type}")
robot = make_robot_from_config(cfg.robot)
robot.connect()
robot_wrapper = RobotWrapper(robot)
# Create robot observation processor
robot_observation_processor = make_default_robot_observation_processor()
robot_action_processor = make_default_robot_action_processor()
# Create action queue for communication between threads
action_queue = ActionQueue(cfg.rtc)
# Start chunk requester thread
get_actions_thread = Thread(
target=get_actions,
args=(policy, robot_wrapper, robot_observation_processor, action_queue, shutdown_event, cfg),
daemon=True,
name="GetActions",
)
get_actions_thread.start()
logger.info("Started get actions thread")
# Start action executor thread
actor_thread = Thread(
target=actor_control,
args=(robot_wrapper, robot_action_processor, action_queue, shutdown_event, cfg),
daemon=True,
name="Actor",
)
actor_thread.start()
logger.info("Started actor thread")
logger.info("Started stop by duration thread")
# Main thread monitors for duration or shutdown
logger.info(f"Running demo for {cfg.duration} seconds...")
start_time = time.time()
while not shutdown_event.is_set() and (time.time() - start_time) < cfg.duration:
time.sleep(10)
# Log queue status periodically
if int(time.time() - start_time) % 5 == 0:
logger.info(f"[MAIN] Action queue size: {action_queue.qsize()}")
if time.time() - start_time > cfg.duration:
break
logger.info("Demo duration reached or shutdown requested")
# Signal shutdown
shutdown_event.set()
# Wait for threads to finish
if get_actions_thread and get_actions_thread.is_alive():
logger.info("Waiting for chunk requester thread to finish...")
get_actions_thread.join()
if actor_thread and actor_thread.is_alive():
logger.info("Waiting for action executor thread to finish...")
actor_thread.join()
# Cleanup robot
if robot:
robot.disconnect()
logger.info("Robot disconnected")
logger.info("Cleanup completed")
if __name__ == "__main__":
demo_cli()
logging.info("RTC demo finished")
@@ -0,0 +1,631 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Profiled version of eval_with_real_robot.py for performance analysis.
This version adds detailed timing measurements for:
- Policy inference
- Preprocessing
- Postprocessing
- Action queue operations
- Robot communication
- Thread execution times
Usage: Same as eval_with_real_robot.py but with profiling output.
"""
import logging
import math
import sys
import time
import traceback
from collections import defaultdict
from dataclasses import dataclass, field
from threading import Event, Lock, Thread
import torch
from torch import Tensor
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
from lerobot.policies.rtc.action_queue import ActionQueue
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.latency_tracker import LatencyTracker
from lerobot.processor.factory import (
make_default_robot_action_processor,
make_default_robot_observation_processor,
)
from lerobot.rl.process import ProcessSignalHandler
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
koch_follower,
so100_follower,
so101_follower,
)
from lerobot.robots.utils import make_robot_from_config
from lerobot.utils.constants import OBS_IMAGES
from lerobot.utils.hub import HubMixin
from lerobot.utils.utils import init_logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class ProfileTimer:
"""Context manager and utility class for timing code sections."""
def __init__(self, name: str, stats_dict: dict):
self.name = name
self.stats_dict = stats_dict
self.start_time = None
def __enter__(self):
self.start_time = time.perf_counter()
return self
def __exit__(self, *args):
elapsed = time.perf_counter() - self.start_time
if self.name not in self.stats_dict:
self.stats_dict[self.name] = []
self.stats_dict[self.name].append(elapsed)
class ProfilingStats:
"""Global profiling statistics collector."""
def __init__(self):
self.stats = defaultdict(list)
self.lock = Lock()
def record(self, name: str, duration: float):
with self.lock:
self.stats[name].append(duration)
def timer(self, name: str):
"""Return a context manager for timing."""
return ProfileTimer(name, self.stats)
def get_summary(self) -> dict[str, dict[str, float]]:
"""Get summary statistics for all timings."""
with self.lock:
summary = {}
for name, times in self.stats.items():
if times:
summary[name] = {
"count": len(times),
"mean": sum(times) / len(times),
"min": min(times),
"max": max(times),
"total": sum(times),
}
return summary
def print_summary(self):
"""Print formatted summary of all timings."""
summary = self.get_summary()
logger.info("\n" + "=" * 80)
logger.info("PROFILING SUMMARY")
logger.info("=" * 80)
# Sort by total time (descending)
sorted_items = sorted(summary.items(), key=lambda x: x[1]["total"], reverse=True)
for name, stats in sorted_items:
logger.info(f"\n{name}:")
logger.info(f" Count: {stats['count']}")
logger.info(f" Mean: {stats['mean']*1000:.2f} ms")
logger.info(f" Min: {stats['min']*1000:.2f} ms")
logger.info(f" Max: {stats['max']*1000:.2f} ms")
logger.info(f" Total: {stats['total']:.2f} s")
logger.info(f" Hz: {stats['count']/stats['total']:.2f}")
logger.info("\n" + "=" * 80)
# Global profiling stats
profiling_stats = ProfilingStats()
class RobotWrapper:
def __init__(self, robot: Robot):
self.robot = robot
self.lock = Lock()
def get_observation(self) -> dict[str, Tensor]:
with profiling_stats.timer("robot.get_observation"):
with self.lock:
return self.robot.get_observation()
def send_action(self, action: Tensor):
with profiling_stats.timer("robot.send_action"):
with self.lock:
self.robot.send_action(action)
def observation_features(self) -> list[str]:
with self.lock:
return self.robot.observation_features
def action_features(self) -> list[str]:
with self.lock:
return self.robot.action_features
@dataclass
class RTCDemoConfig(HubMixin):
"""Configuration for RTC demo with action chunking policies and real robots."""
# Policy configuration
policy: PreTrainedConfig | None = None
# Robot configuration
robot: RobotConfig | None = None
# RTC configuration
rtc: RTCConfig = field(
default_factory=lambda: RTCConfig(
execution_horizon=10,
max_guidance_weight=1.0,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
)
)
# Demo parameters
duration: float = 30.0 # Duration to run the demo (seconds)
fps: float = 10.0 # Action execution frequency (Hz)
# Compute device
device: str | None = None # Device to run on (cuda, cpu, auto)
# Get new actions horizon. The amount of executed steps after which will be requested new actions.
# It should be higher than inference delay + execution horizon.
action_queue_size_to_get_new_actions: int = 30
# Task to execute
task: str = field(default="", metadata={"help": "Task to execute"})
# Torch compile configuration
use_torch_compile: bool = field(
default=False,
metadata={"help": "Use torch.compile for faster inference (PyTorch 2.0+)"},
)
torch_compile_backend: str = field(
default="inductor",
metadata={"help": "Backend for torch.compile (inductor, aot_eager, cudagraphs)"},
)
torch_compile_mode: str = field(
default="default",
metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"},
)
torch_compile_disable_cudagraphs: bool = field(
default=True,
metadata={
"help": "Disable CUDA graphs in torch.compile. Required due to in-place tensor "
"operations in denoising loop (x_t += dt * v_t) which cause tensor aliasing issues."
},
)
def __post_init__(self):
# HACK: We parse again the cli args here to get the pretrained path if there was one.
policy_path = parser.get_path_arg("policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
else:
raise ValueError("Policy path is required")
# Validate that robot configuration is provided
if self.robot is None:
raise ValueError("Robot configuration must be provided")
@classmethod
def __get_path_fields__(cls) -> list[str]:
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
return ["policy"]
def is_image_key(k: str) -> bool:
return k.startswith(OBS_IMAGES)
def get_actions(
policy,
robot: RobotWrapper,
robot_observation_processor,
action_queue: ActionQueue,
shutdown_event: Event,
cfg: RTCDemoConfig,
):
"""Thread function to request action chunks from the policy with profiling.
Args:
policy: The policy instance (SmolVLA, Pi0, etc.)
robot: The robot instance for getting observations
robot_observation_processor: Processor for raw robot observations
action_queue: Queue to put new action chunks
shutdown_event: Event to signal shutdown
cfg: Demo configuration
"""
try:
logger.info("[GET_ACTIONS] Starting get actions thread")
latency_tracker = LatencyTracker() # Track latency of action chunks
fps = cfg.fps
time_per_chunk = 1.0 / fps
dataset_features = hw_to_dataset_features(robot.observation_features(), "observation")
policy_device = policy.config.device
# Load preprocessor and postprocessor from pretrained files
logger.info(f"[GET_ACTIONS] Loading preprocessor/postprocessor from {cfg.policy.pretrained_path}")
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
dataset_stats=None, # Will load from pretrained processor files
preprocessor_overrides={
"device_processor": {"device": cfg.policy.device},
},
)
logger.info("[GET_ACTIONS] Preprocessor/postprocessor loaded successfully with embedded stats")
get_actions_threshold = cfg.action_queue_size_to_get_new_actions
if not cfg.rtc.enabled:
get_actions_threshold = 0
inference_count = 0
while not shutdown_event.is_set():
if action_queue.qsize() <= get_actions_threshold:
with profiling_stats.timer("get_actions.total_iteration"):
inference_count += 1
logger.info(f"[GET_ACTIONS] Starting inference #{inference_count}")
current_time = time.perf_counter()
action_index_before_inference = action_queue.get_action_index()
with profiling_stats.timer("get_actions.get_prev_actions"):
prev_actions = action_queue.get_left_over()
inference_latency = latency_tracker.max()
inference_delay = math.ceil(inference_latency / time_per_chunk)
# Get observation
obs = robot.get_observation()
# Apply robot observation processor
with profiling_stats.timer("get_actions.robot_obs_processing"):
obs_processed = robot_observation_processor(obs)
# Build dataset frame
with profiling_stats.timer("get_actions.build_dataset_frame"):
obs_with_policy_features = build_dataset_frame(
dataset_features, obs_processed, prefix="observation"
)
# Convert to tensors and normalize
with profiling_stats.timer("get_actions.tensor_conversion"):
for name in obs_with_policy_features:
obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name])
if "image" in name:
obs_with_policy_features[name] = (
obs_with_policy_features[name].type(torch.float32) / 255
)
obs_with_policy_features[name] = (
obs_with_policy_features[name].permute(2, 0, 1).contiguous()
)
obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0)
obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device)
obs_with_policy_features["task"] = [cfg.task]
obs_with_policy_features["robot_type"] = (
robot.robot.name if hasattr(robot.robot, "name") else ""
)
# Preprocessing
with profiling_stats.timer("get_actions.preprocessing"):
preproceseded_obs = preprocessor(obs_with_policy_features)
# Policy inference
with profiling_stats.timer("get_actions.policy_inference"):
actions = policy.predict_action_chunk(
preproceseded_obs,
inference_delay=inference_delay,
prev_chunk_left_over=prev_actions,
)
# Clone for RTC
with profiling_stats.timer("get_actions.clone_actions"):
original_actions = actions.squeeze(0).clone()
# Postprocessing
with profiling_stats.timer("get_actions.postprocessing"):
postprocessed_actions = postprocessor(actions)
postprocessed_actions = postprocessed_actions.squeeze(0)
# Update latency tracker
new_latency = time.perf_counter() - current_time
new_delay = math.ceil(new_latency / time_per_chunk)
latency_tracker.add(new_latency)
logger.info(
f"[GET_ACTIONS] Inference #{inference_count} completed in {new_latency*1000:.2f}ms "
f"(delay={new_delay} chunks)"
)
if cfg.action_queue_size_to_get_new_actions < cfg.rtc.execution_horizon + new_delay:
logger.warning(
"[GET_ACTIONS] cfg.action_queue_size_to_get_new_actions Too small, "
"It should be higher than inference delay + execution horizon."
)
# Merge into action queue
with profiling_stats.timer("get_actions.action_queue_merge"):
action_queue.merge(
original_actions, postprocessed_actions, new_delay, action_index_before_inference
)
else:
# Small sleep to prevent busy waiting
time.sleep(0.1)
logger.info("[GET_ACTIONS] get actions thread shutting down")
except Exception as e:
logger.error(f"[GET_ACTIONS] Fatal exception in get_actions thread: {e}")
logger.error(traceback.format_exc())
sys.exit(1)
def actor_control(
robot: RobotWrapper,
robot_action_processor,
action_queue: ActionQueue,
shutdown_event: Event,
cfg: RTCDemoConfig,
):
"""Thread function to execute actions on the robot with profiling.
Args:
robot: The robot instance
action_queue: Queue to get actions from
shutdown_event: Event to signal shutdown
cfg: Demo configuration
"""
try:
logger.info("[ACTOR] Starting actor thread")
action_count = 0
action_interval = 1.0 / cfg.fps
while not shutdown_event.is_set():
start_time = time.perf_counter()
with profiling_stats.timer("actor.total_iteration"):
# Get action from queue
with profiling_stats.timer("actor.queue_get"):
action = action_queue.get()
if action is not None:
# Process action
with profiling_stats.timer("actor.action_processing"):
action = action.cpu()
action_dict = {key: action[i].item() for i, key in enumerate(robot.action_features())}
action_processed = robot_action_processor((action_dict, None))
# Send to robot (includes robot.send_action timing)
robot.send_action(action_processed)
action_count += 1
# Sleep to maintain target FPS
dt_s = time.perf_counter() - start_time
sleep_time = max(0, (action_interval - dt_s) - 0.001)
if sleep_time > 0:
time.sleep(sleep_time)
logger.info(f"[ACTOR] Actor thread shutting down. Total actions executed: {action_count}")
except Exception as e:
logger.error(f"[ACTOR] Fatal exception in actor_control thread: {e}")
logger.error(traceback.format_exc())
sys.exit(1)
def _apply_torch_compile(policy, cfg: RTCDemoConfig):
"""Apply torch.compile to the policy's predict_action_chunk method.
Args:
policy: Policy instance to compile
cfg: Configuration containing torch compile settings
Returns:
Policy with compiled predict_action_chunk method
"""
# PI models handle their own compilation
if policy.type == "pi05" or policy.type == "pi0":
return policy
try:
# Check if torch.compile is available (PyTorch 2.0+)
if not hasattr(torch, "compile"):
logger.warning(
f"torch.compile is not available. Requires PyTorch 2.0+. "
f"Current version: {torch.__version__}. Skipping compilation."
)
return policy
logger.info("Applying torch.compile to predict_action_chunk...")
logger.info(f" Backend: {cfg.torch_compile_backend}")
logger.info(f" Mode: {cfg.torch_compile_mode}")
logger.info(f" Disable CUDA graphs: {cfg.torch_compile_disable_cudagraphs}")
# Compile the predict_action_chunk method
compile_kwargs = {
"backend": cfg.torch_compile_backend,
"mode": cfg.torch_compile_mode,
}
# Disable CUDA graphs if requested (prevents tensor aliasing issues)
if cfg.torch_compile_disable_cudagraphs:
compile_kwargs["options"] = {"triton.cudagraphs": False}
original_method = policy.predict_action_chunk
compiled_method = torch.compile(original_method, **compile_kwargs)
policy.predict_action_chunk = compiled_method
logger.info("✓ Successfully compiled predict_action_chunk")
except Exception as e:
logger.error(f"Failed to apply torch.compile: {e}")
logger.warning("Continuing without torch.compile")
return policy
@parser.wrap()
def demo_cli(cfg: RTCDemoConfig):
"""Main entry point for RTC demo with profiling."""
# Initialize logging
init_logging()
logger.info(f"Using device: {cfg.device}")
logger.info("=" * 80)
logger.info("PROFILING MODE ENABLED")
logger.info("=" * 80)
# Setup signal handler for graceful shutdown
signal_handler = ProcessSignalHandler(use_threads=True, display_pid=False)
shutdown_event = signal_handler.shutdown_event
policy = None
robot = None
get_actions_thread = None
actor_thread = None
policy_class = get_policy_class(cfg.policy.type)
# Load config and set compile_model for pi0/pi05 models
config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
if cfg.policy.type == "pi05" or cfg.policy.type == "pi0":
config.compile_model = cfg.use_torch_compile
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)
# Turn on RTC
policy.config.rtc_config = cfg.rtc
# Init RTC processor
policy.init_rtc_processor()
assert policy.name in ["smolvla", "pi05", "pi0"], "Only smolvla, pi05, and pi0 are supported for RTC"
policy = policy.to(cfg.device)
policy.eval()
# Apply torch.compile to predict_action_chunk method if enabled
if cfg.use_torch_compile:
policy = _apply_torch_compile(policy, cfg)
# Create robot
logger.info(f"Initializing robot: {cfg.robot.type}")
robot = make_robot_from_config(cfg.robot)
robot.connect()
robot_wrapper = RobotWrapper(robot)
# Create robot observation processor
robot_observation_processor = make_default_robot_observation_processor()
robot_action_processor = make_default_robot_action_processor()
# Create action queue for communication between threads
action_queue = ActionQueue(cfg.rtc)
# Start chunk requester thread
get_actions_thread = Thread(
target=get_actions,
args=(policy, robot_wrapper, robot_observation_processor, action_queue, shutdown_event, cfg),
daemon=True,
name="GetActions",
)
get_actions_thread.start()
logger.info("Started get actions thread")
# Start action executor thread
actor_thread = Thread(
target=actor_control,
args=(robot_wrapper, robot_action_processor, action_queue, shutdown_event, cfg),
daemon=True,
name="Actor",
)
actor_thread.start()
logger.info("Started actor thread")
logger.info("Started stop by duration thread")
# Main thread monitors for duration or shutdown
logger.info(f"Running demo for {cfg.duration} seconds...")
start_time = time.time()
while not shutdown_event.is_set() and (time.time() - start_time) < cfg.duration:
time.sleep(10)
# Log queue status periodically
if int(time.time() - start_time) % 5 == 0:
logger.info(f"[MAIN] Action queue size: {action_queue.qsize()}")
if time.time() - start_time > cfg.duration:
break
logger.info("Demo duration reached or shutdown requested")
# Signal shutdown
shutdown_event.set()
# Wait for threads to finish
if get_actions_thread and get_actions_thread.is_alive():
logger.info("Waiting for chunk requester thread to finish...")
get_actions_thread.join()
if actor_thread and actor_thread.is_alive():
logger.info("Waiting for action executor thread to finish...")
actor_thread.join()
# Cleanup robot
if robot:
robot.disconnect()
logger.info("Robot disconnected")
# Print profiling summary
profiling_stats.print_summary()
logger.info("Cleanup completed")
if __name__ == "__main__":
demo_cli()
logging.info("RTC demo finished")
+358
View File
@@ -0,0 +1,358 @@
#!/usr/bin/env python
"""
Comprehensive profiling script for Pi0 with RTC.
This script demonstrates how to use all the profiling tools to identify
bottlenecks in Pi0 policy inference with RTC enabled.
It profiles:
1. Overall inference time
2. RTC-specific operations (guidance, weights, etc.)
3. Preprocessing/postprocessing
4. Individual method timings
Usage:
uv run examples/rtc/profile_pi0_rtc_detailed.py \
--policy_path=helper2424/pi05_check_rtc \
--device=mps \
--num_iterations=20 \
--execution_horizon=20 \
--enable_rtc_profiling
"""
import argparse
import logging
import sys
import time
import numpy as np
import torch
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.utils.profiling import (
ProfileContext,
clear_profiling_stats,
enable_profiling,
get_profiling_stats,
print_profiling_summary,
)
# Import monkey patching for RTC profiling
try:
from examples.rtc.add_rtc_profiling import monkey_patch_rtc_profiling
except ImportError:
logging.warning("Could not import add_rtc_profiling, detailed RTC profiling disabled")
monkey_patch_rtc_profiling = None
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def create_mock_observation(policy_config, device: str) -> dict:
"""Create a mock observation matching policy requirements.
Args:
policy_config: Policy configuration
device: Device to create tensors on
Returns:
Mock observation dictionary
"""
obs = {}
# Create mock state observation
state_dim = 10 # Typical robot state dimension
obs["observation.state"] = torch.randn(1, state_dim, device=device)
# Create mock images if needed
# For Pi0, we typically need at least one image
image_height = 224
image_width = 224
# Common image keys for Pi0
image_keys = ["observation.images.gripper", "observation.images.front"]
for key in image_keys:
# Images should be [B, C, H, W] and normalized to [0, 1]
obs[key] = torch.rand(1, 3, image_height, image_width, device=device)
# Add task
obs["task"] = ["Pick up the object"]
# Add language tokens and attention mask (required for Pi0)
# These are mock values - in real usage they come from tokenizer
max_seq_len = 32
obs["observation.language_tokens"] = torch.randint(0, 1000, (1, max_seq_len), device=device)
obs["observation.language_attention_mask"] = torch.ones(1, max_seq_len, device=device)
return obs
def profile_single_iteration(
policy,
preprocessor,
postprocessor,
observation: dict,
prev_actions: torch.Tensor | None,
use_rtc: bool,
inference_delay: int = 0,
) -> tuple[torch.Tensor, torch.Tensor | None, dict]:
"""Profile a single inference iteration.
Args:
policy: Policy instance
preprocessor: Observation preprocessor
postprocessor: Action postprocessor
observation: Input observation
prev_actions: Previous action chunk (for RTC)
use_rtc: Whether RTC is enabled
inference_delay: Inference delay in timesteps
Returns:
Tuple of (actions, new_prev_actions, timings)
"""
timings = {}
with ProfileContext("iteration.total"):
# Preprocessing
with ProfileContext("iteration.preprocessing"):
preprocessed_obs = preprocessor(observation)
# Policy inference
with ProfileContext("iteration.policy_inference"):
if use_rtc:
actions = policy.predict_action_chunk(
preprocessed_obs,
inference_delay=inference_delay,
prev_chunk_left_over=prev_actions,
)
else:
actions = policy.predict_action_chunk(preprocessed_obs)
# Clone for next iteration (if RTC)
new_prev_actions = None
if use_rtc:
with ProfileContext("iteration.prepare_prev_actions"):
execution_horizon = policy.config.rtc_config.execution_horizon
if actions.shape[1] > execution_horizon:
new_prev_actions = actions[:, execution_horizon:].clone()
# Postprocessing
with ProfileContext("iteration.postprocessing"):
processed_actions = postprocessor(actions)
return processed_actions, new_prev_actions, timings
def main():
parser = argparse.ArgumentParser(description="Detailed profiling for Pi0 with RTC")
parser.add_argument("--policy_path", type=str, required=True, help="Path to pretrained policy")
parser.add_argument("--device", type=str, default="cuda", help="Device (cuda/cpu/mps)")
parser.add_argument("--num_iterations", type=int, default=20, help="Number of iterations")
parser.add_argument("--execution_horizon", type=int, default=10, help="RTC execution horizon")
parser.add_argument("--warmup_iterations", type=int, default=5, help="Warmup iterations")
parser.add_argument("--enable_rtc_profiling", action="store_true", help="Enable detailed RTC profiling")
parser.add_argument("--use_torch_compile", action="store_true", help="Use torch.compile")
args = parser.parse_args()
logger.info("="*80)
logger.info("DETAILED PI0 RTC PROFILING")
logger.info("="*80)
logger.info(f"Policy: {args.policy_path}")
logger.info(f"Device: {args.device}")
logger.info(f"Iterations: {args.num_iterations}")
logger.info(f"Execution Horizon: {args.execution_horizon}")
logger.info(f"RTC Profiling: {args.enable_rtc_profiling}")
logger.info("="*80 + "\n")
# Enable profiling
enable_profiling()
# Apply RTC profiling if requested
if args.enable_rtc_profiling:
if monkey_patch_rtc_profiling is not None:
monkey_patch_rtc_profiling()
logger.info("✓ Detailed RTC profiling enabled\n")
else:
logger.warning("⚠ Could not enable detailed RTC profiling\n")
# Load policy
logger.info("Loading policy...")
config = PreTrainedConfig.from_pretrained(args.policy_path)
if hasattr(config, "compile_model"):
config.compile_model = args.use_torch_compile
policy_class = get_policy_class(config.type)
policy = policy_class.from_pretrained(args.policy_path, config=config)
# Configure RTC
policy.config.rtc_config = RTCConfig(
enabled=True,
execution_horizon=args.execution_horizon,
max_guidance_weight=1.0,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
)
policy.init_rtc_processor()
policy = policy.to(args.device)
policy.eval()
logger.info(f"✓ Policy loaded: {config.type}\n")
# Create preprocessor and postprocessor
logger.info("Loading preprocessor/postprocessor...")
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=config,
pretrained_path=args.policy_path,
dataset_stats=None,
preprocessor_overrides={
"device_processor": {"device": args.device},
},
)
logger.info("✓ Preprocessor/postprocessor loaded\n")
# Create mock observation
logger.info("Creating mock observation...")
observation = create_mock_observation(config, args.device)
logger.info("✓ Mock observation created\n")
# Warmup
logger.info(f"Warming up ({args.warmup_iterations} iterations)...")
prev_actions = None
for i in range(args.warmup_iterations):
with torch.no_grad():
_, prev_actions, _ = profile_single_iteration(
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
observation=observation,
prev_actions=prev_actions,
use_rtc=True,
inference_delay=0,
)
# Clear warmup stats
clear_profiling_stats()
logger.info("✓ Warmup complete\n")
# Profiled run WITH RTC
logger.info(f"Running profiled iterations WITH RTC ({args.num_iterations} iterations)...")
prev_actions = None
iteration_times = []
for i in range(args.num_iterations):
start = time.perf_counter()
with torch.no_grad():
_, prev_actions, _ = profile_single_iteration(
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
observation=observation,
prev_actions=prev_actions,
use_rtc=True,
inference_delay=0,
)
# Sync CUDA if needed
if args.device.startswith("cuda"):
torch.cuda.synchronize()
elapsed = time.perf_counter() - start
iteration_times.append(elapsed)
if (i + 1) % 5 == 0:
logger.info(f" Completed {i+1}/{args.num_iterations}")
logger.info("✓ Profiling complete\n")
# Print summary statistics
logger.info("\n" + "="*80)
logger.info("ITERATION TIMING SUMMARY")
logger.info("="*80)
times_arr = np.array(iteration_times)
logger.info(f"Mean time: {np.mean(times_arr)*1000:.2f} ms")
logger.info(f"Median time: {np.median(times_arr)*1000:.2f} ms")
logger.info(f"Std dev: {np.std(times_arr)*1000:.2f} ms")
logger.info(f"Min time: {np.min(times_arr)*1000:.2f} ms")
logger.info(f"Max time: {np.max(times_arr)*1000:.2f} ms")
logger.info(f"Total time: {np.sum(times_arr):.2f} s")
logger.info(f"Throughput: {len(times_arr)/np.sum(times_arr):.2f} iter/s")
logger.info("="*80 + "\n")
# Print detailed profiling breakdown
print_profiling_summary(sort_by="total")
# Print key insights
stats = get_profiling_stats()
logger.info("\n" + "="*80)
logger.info("KEY INSIGHTS")
logger.info("="*80)
# Find bottlenecks
if stats:
policy_inference_time = stats.get("iteration.policy_inference", {}).get("mean", 0)
preprocessing_time = stats.get("iteration.preprocessing", {}).get("mean", 0)
postprocessing_time = stats.get("iteration.postprocessing", {}).get("mean", 0)
total_time = policy_inference_time + preprocessing_time + postprocessing_time
if total_time > 0:
logger.info(f"\nTime breakdown:")
logger.info(f" Policy inference: {policy_inference_time*1000:.2f} ms ({policy_inference_time/total_time*100:.1f}%)")
logger.info(f" Preprocessing: {preprocessing_time*1000:.2f} ms ({preprocessing_time/total_time*100:.1f}%)")
logger.info(f" Postprocessing: {postprocessing_time*1000:.2f} ms ({postprocessing_time/total_time*100:.1f}%)")
# RTC-specific insights
if args.enable_rtc_profiling:
rtc_guidance = stats.get("rtc.denoise_step.guidance_computation", {}).get("mean", 0)
rtc_autograd = stats.get("rtc.denoise_step.autograd_correction", {}).get("mean", 0)
rtc_base = stats.get("rtc.denoise_step.base_denoising", {}).get("mean", 0)
if rtc_guidance > 0:
logger.info(f"\nRTC breakdown:")
logger.info(f" Base denoising: {rtc_base*1000:.2f} ms")
logger.info(f" Guidance compute: {rtc_guidance*1000:.2f} ms")
logger.info(f" Autograd correct: {rtc_autograd*1000:.2f} ms")
logger.info(f" RTC overhead: {(rtc_guidance - rtc_base)*1000:.2f} ms")
# Recommendations
logger.info("\nRecommendations:")
if preprocessing_time > policy_inference_time * 0.3:
logger.info(" ⚠ Preprocessing is taking >30% of time")
logger.info(" → Consider reducing image resolution")
logger.info(" → Consider using fewer cameras")
if args.enable_rtc_profiling and rtc_autograd > rtc_base * 0.5:
logger.info(" ⚠ RTC autograd overhead is significant")
logger.info(" → This is expected, but consider increasing execution_horizon")
logger.info(" → Try torch.compile if not already enabled")
if not args.use_torch_compile:
logger.info(" 💡 torch.compile not enabled")
logger.info(" → Try --use_torch_compile for potential speedup")
logger.info("="*80 + "\n")
if __name__ == "__main__":
try:
main()
except KeyboardInterrupt:
logger.info("\n\nProfiling interrupted by user")
sys.exit(0)
except Exception as e:
logger.error(f"\n\nError during profiling: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
+347
View File
@@ -0,0 +1,347 @@
#!/usr/bin/env python
"""
Script to compare performance with and without RTC enabled.
This script helps identify whether RTC is actually improving or degrading performance
by running multiple inference passes and collecting detailed timing statistics.
Usage:
# Profile with mock data (no robot needed)
uv run examples/rtc/profile_rtc_comparison.py \
--policy_path=helper2424/pi05_check_rtc \
--device=mps \
--num_iterations=50
# Profile with specific RTC config
uv run examples/rtc/profile_rtc_comparison.py \
--policy_path=helper2424/pi05_check_rtc \
--device=mps \
--num_iterations=50 \
--execution_horizon=20
"""
import argparse
import logging
import time
from dataclasses import dataclass
import numpy as np
import torch
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.utils.profiling import (
clear_profiling_stats,
enable_profiling,
get_profiling_stats,
print_profiling_summary,
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@dataclass
class ProfileResults:
"""Results from profiling run."""
mode: str # "with_rtc" or "without_rtc"
mean_time: float
std_time: float
min_time: float
max_time: float
times: list[float]
throughput: float # iterations per second
def create_mock_observation(policy, device: str) -> dict:
"""Create a mock observation for testing.
Args:
policy: Policy instance
device: Device to create tensors on
Returns:
Mock observation dictionary
"""
# Get expected input shapes from policy config
# This is a simplified version - adjust based on actual policy requirements
obs = {}
# Mock image observations (if needed)
if hasattr(policy.config, "input_shapes"):
for key, shape in policy.config.input_shapes.items():
if "image" in key:
# Typical image shape: (batch, channels, height, width)
obs[key] = torch.randn(1, *shape, device=device)
else:
obs[key] = torch.randn(1, *shape, device=device)
# Add task if needed
if "task" in policy.config.__dict__ or hasattr(policy, "accepts_task"):
obs["task"] = ["Pick up the object"]
# Mock state observation
obs["observation.state"] = torch.randn(1, 10, device=device) # Adjust size as needed
return obs
def profile_inference(
policy, observation: dict, num_iterations: int, use_rtc: bool, execution_horizon: int = 10
) -> ProfileResults:
"""Profile policy inference with or without RTC.
Args:
policy: Policy instance
observation: Observation dictionary
num_iterations: Number of inference iterations to run
use_rtc: Whether to enable RTC
execution_horizon: Execution horizon for RTC
Returns:
ProfileResults with timing statistics
"""
mode = "with_rtc" if use_rtc else "without_rtc"
logger.info(f"\n{'='*80}")
logger.info(f"Profiling: {mode.upper()}")
logger.info(f"{'='*80}")
# Configure RTC
if use_rtc:
policy.config.rtc_config.enabled = True
policy.config.rtc_config.execution_horizon = execution_horizon
policy.init_rtc_processor()
else:
policy.config.rtc_config.enabled = False
times = []
prev_actions = None
# Warmup
logger.info("Warming up (5 iterations)...")
for _ in range(5):
with torch.no_grad():
if use_rtc:
_ = policy.predict_action_chunk(
observation, inference_delay=0, prev_chunk_left_over=prev_actions
)
else:
_ = policy.predict_action_chunk(observation)
# Actual profiling
logger.info(f"Running {num_iterations} profiled iterations...")
for i in range(num_iterations):
start = time.perf_counter()
with torch.no_grad():
if use_rtc:
actions = policy.predict_action_chunk(
observation, inference_delay=0, prev_chunk_left_over=prev_actions
)
# Simulate consuming some actions for next iteration
if actions.shape[1] > execution_horizon:
prev_actions = actions[:, execution_horizon:].clone()
else:
prev_actions = None
else:
actions = policy.predict_action_chunk(observation)
# Synchronize if using CUDA
if observation["observation.state"].device.type == "cuda":
torch.cuda.synchronize()
elapsed = time.perf_counter() - start
times.append(elapsed)
if (i + 1) % 10 == 0:
logger.info(f" Completed {i+1}/{num_iterations} iterations")
# Calculate statistics
times_arr = np.array(times)
results = ProfileResults(
mode=mode,
mean_time=float(np.mean(times_arr)),
std_time=float(np.std(times_arr)),
min_time=float(np.min(times_arr)),
max_time=float(np.max(times_arr)),
times=times,
throughput=num_iterations / sum(times),
)
logger.info(f"\nResults for {mode}:")
logger.info(f" Mean time: {results.mean_time*1000:.2f} ms")
logger.info(f" Std dev: {results.std_time*1000:.2f} ms")
logger.info(f" Min time: {results.min_time*1000:.2f} ms")
logger.info(f" Max time: {results.max_time*1000:.2f} ms")
logger.info(f" Throughput: {results.throughput:.2f} iter/s")
return results
def compare_results(results_without_rtc: ProfileResults, results_with_rtc: ProfileResults):
"""Compare and print results from both runs.
Args:
results_without_rtc: Results from run without RTC
results_with_rtc: Results from run with RTC
"""
logger.info(f"\n{'='*80}")
logger.info("COMPARISON SUMMARY")
logger.info(f"{'='*80}")
mean_diff = results_with_rtc.mean_time - results_without_rtc.mean_time
mean_diff_pct = (mean_diff / results_without_rtc.mean_time) * 100
throughput_diff = results_with_rtc.throughput - results_without_rtc.throughput
throughput_diff_pct = (throughput_diff / results_without_rtc.throughput) * 100
logger.info(f"\n{'Metric':<30} {'Without RTC':>15} {'With RTC':>15} {'Difference':>15}")
logger.info("-" * 80)
logger.info(
f"{'Mean time (ms)':<30} "
f"{results_without_rtc.mean_time*1000:>15.2f} "
f"{results_with_rtc.mean_time*1000:>15.2f} "
f"{mean_diff*1000:>+15.2f}"
)
logger.info(
f"{'Std dev (ms)':<30} "
f"{results_without_rtc.std_time*1000:>15.2f} "
f"{results_with_rtc.std_time*1000:>15.2f} "
f"{(results_with_rtc.std_time - results_without_rtc.std_time)*1000:>+15.2f}"
)
logger.info(
f"{'Min time (ms)':<30} "
f"{results_without_rtc.min_time*1000:>15.2f} "
f"{results_with_rtc.min_time*1000:>15.2f} "
f"{(results_with_rtc.min_time - results_without_rtc.min_time)*1000:>+15.2f}"
)
logger.info(
f"{'Max time (ms)':<30} "
f"{results_without_rtc.max_time*1000:>15.2f} "
f"{results_with_rtc.max_time*1000:>15.2f} "
f"{(results_with_rtc.max_time - results_without_rtc.max_time)*1000:>+15.2f}"
)
logger.info(
f"{'Throughput (iter/s)':<30} "
f"{results_without_rtc.throughput:>15.2f} "
f"{results_with_rtc.throughput:>15.2f} "
f"{throughput_diff:>+15.2f}"
)
logger.info(f"\n{'='*80}")
logger.info("VERDICT")
logger.info(f"{'='*80}")
if mean_diff_pct < -5:
logger.info(f"✓ RTC is FASTER by {abs(mean_diff_pct):.1f}%")
logger.info(f" Mean time reduced by {abs(mean_diff)*1000:.2f} ms")
elif mean_diff_pct > 5:
logger.info(f"✗ RTC is SLOWER by {mean_diff_pct:.1f}%")
logger.info(f" Mean time increased by {mean_diff*1000:.2f} ms")
logger.info("\n Possible reasons:")
logger.info(" - RTC overhead exceeds benefits at current execution horizon")
logger.info(" - Inference delay calculation not accounting for RTC processing")
logger.info(" - Additional tensor operations in RTC guidance")
else:
logger.info(f"≈ Performance is SIMILAR (difference: {mean_diff_pct:+.1f}%)")
logger.info(f"{'='*80}\n")
def main():
parser = argparse.ArgumentParser(description="Profile RTC performance")
parser.add_argument(
"--policy_path", type=str, required=True, help="Path to pretrained policy"
)
parser.add_argument(
"--device", type=str, default="cuda", help="Device to run on (cuda/cpu/mps)"
)
parser.add_argument(
"--num_iterations", type=int, default=50, help="Number of inference iterations"
)
parser.add_argument(
"--execution_horizon", type=int, default=10, help="RTC execution horizon"
)
parser.add_argument(
"--enable_detailed_profiling",
action="store_true",
help="Enable detailed method-level profiling",
)
parser.add_argument(
"--use_torch_compile", action="store_true", help="Use torch.compile for faster inference"
)
args = parser.parse_args()
# Load policy
logger.info(f"Loading policy from {args.policy_path}")
config = PreTrainedConfig.from_pretrained(args.policy_path)
policy_class = get_policy_class(config.type)
# Set compile flag if needed
if hasattr(config, "compile_model"):
config.compile_model = args.use_torch_compile
policy = policy_class.from_pretrained(args.policy_path, config=config)
# Initialize RTC config
policy.config.rtc_config = RTCConfig(
execution_horizon=args.execution_horizon,
max_guidance_weight=1.0,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
)
policy = policy.to(args.device)
policy.eval()
logger.info(f"Policy loaded: {config.type}")
logger.info(f"Device: {args.device}")
logger.info(f"Execution horizon: {args.execution_horizon}")
# Create mock observation
logger.info("Creating mock observation...")
observation = create_mock_observation(policy, args.device)
# Enable detailed profiling if requested
if args.enable_detailed_profiling:
enable_profiling()
logger.info("Detailed profiling enabled")
# Profile without RTC
results_without_rtc = profile_inference(
policy=policy,
observation=observation,
num_iterations=args.num_iterations,
use_rtc=False,
execution_horizon=args.execution_horizon,
)
if args.enable_detailed_profiling:
logger.info("\nDetailed profiling stats (WITHOUT RTC):")
print_profiling_summary()
clear_profiling_stats()
# Profile with RTC
results_with_rtc = profile_inference(
policy=policy,
observation=observation,
num_iterations=args.num_iterations,
use_rtc=True,
execution_horizon=args.execution_horizon,
)
if args.enable_detailed_profiling:
logger.info("\nDetailed profiling stats (WITH RTC):")
print_profiling_summary()
# Compare results
compare_results(results_without_rtc, results_with_rtc)
if __name__ == "__main__":
main()
+2 -1
View File
@@ -98,6 +98,7 @@ pygame-dep = ["pygame>=2.5.1,<2.7.0"]
placo-dep = ["placo>=0.9.6,<0.10.0"]
transformers-dep = ["transformers>=4.53.0,<5.0.0"]
grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"] # TODO: Bumb dependency (compatible with wandb)
matplotlib-dep = ["matplotlib>=3.10.3,<4.0.0"]
# Motors
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
@@ -132,7 +133,7 @@ groot = [
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
# Features
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"]
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
# Development
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1"]
+7
View File
@@ -43,3 +43,10 @@ class NormalizationMode(str, Enum):
class PolicyFeature:
type: FeatureType
shape: tuple[int, ...]
class RTCAttentionSchedule(str, Enum):
ZEROS = "ZEROS"
ONES = "ONES"
LINEAR = "LINEAR"
EXP = "EXP"
@@ -20,6 +20,7 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.utils.constants import OBS_IMAGES
@@ -47,6 +48,9 @@ class PI0Config(PreTrainedConfig):
min_period: float = 4e-3
max_period: float = 4.0
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None
image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
# Add empty images. Used to add empty cameras when no image features are present.
+83 -15
View File
@@ -19,11 +19,12 @@ import logging
import math
from collections import deque
from pathlib import Path
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING, Literal, TypedDict
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from typing_extensions import Unpack
from lerobot.utils.import_utils import _transformers_available
@@ -42,6 +43,7 @@ else:
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.utils.constants import (
ACTION,
OBS_LANGUAGE_ATTENTION_MASK,
@@ -51,6 +53,12 @@ from lerobot.utils.constants import (
)
class ActionSelectKwargs(TypedDict, total=False):
inference_delay: int | None
prev_chunk_left_over: Tensor | None
execution_horizon: int | None
def get_safe_dtype(target_dtype, device_type):
"""Get a safe dtype for the given device type."""
if device_type == "mps" and target_dtype == torch.float64:
@@ -503,9 +511,10 @@ class PaliGemmaWithExpertModel(
class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
"""Core PI0 PyTorch model."""
def __init__(self, config: PI0Config):
def __init__(self, config: PI0Config, rtc_processor: RTCProcessor | None = None):
super().__init__()
self.config = config
self.rtc_processor = rtc_processor
paligemma_config = get_gemma_config(config.paligemma_variant)
action_expert_config = get_gemma_config(config.action_expert_variant)
@@ -560,6 +569,9 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
logging.info("Disabled gradient checkpointing for PI0Pytorch model")
def _rtc_enabled(self):
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def _apply_checkpoint(self, func, *args, **kwargs):
"""Helper method to apply gradient checkpointing if enabled."""
if self.gradient_checkpointing_enabled and self.training:
@@ -756,7 +768,15 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
@torch.no_grad() # see openpi `sample_actions` (slightly adapted)
def sample_actions(
self, images, img_masks, lang_tokens, lang_masks, state, noise=None, num_steps=None
self,
images,
img_masks,
lang_tokens,
lang_masks,
state,
noise=None,
num_steps=None,
**kwargs: Unpack[ActionSelectKwargs],
) -> Tensor:
"""Do a full inference forward and compute the action."""
if num_steps is None:
@@ -798,14 +818,41 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
time = torch.tensor(1.0, dtype=torch.float32, device=device)
while time >= -dt / 2:
expanded_time = time.expand(bsize)
v_t = self.denoise_step(
state,
prefix_pad_masks,
past_key_values,
x_t,
expanded_time,
)
x_t = x_t + dt * v_t
# Define a closure function to properly capture expanded_time
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
return self.denoise_step(
state=state,
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
x_t=input_x_t,
timestep=current_timestep,
)
if self._rtc_enabled():
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
v_t = self.rtc_processor.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk_left_over,
inference_delay=inference_delay,
time=time,
original_denoise_step_partial=denoise_step_partial_call,
execution_horizon=execution_horizon,
)
else:
v_t = denoise_step_partial_call(x_t)
# Euler step
x_t += dt * v_t
# Record x_t and v_t after Euler step
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
time += dt
return x_t
@@ -869,7 +916,8 @@ class PI0Policy(PreTrainedPolicy):
self.config = config
# Initialize the core PI0 model
self.model = PI0Pytorch(config)
self.init_rtc_processor()
self.model = PI0Pytorch(config, rtc_processor=self.rtc_processor)
# Enable gradient checkpointing if requested
if config.gradient_checkpointing:
@@ -1059,6 +1107,22 @@ class PI0Policy(PreTrainedPolicy):
ACTION: deque(maxlen=self.config.n_action_steps),
}
def init_rtc_processor(self):
"""Initialize RTC processor if RTC is enabled in config."""
self.rtc_processor = None
# Create processor if config provided
# If RTC is not enabled - we can still track the denoising data
if self.config.rtc_config is not None:
self.rtc_processor = RTCProcessor(self.config.rtc_config)
model_value = getattr(self, "model", None)
if model_value is not None:
model_value.rtc_processor = self.rtc_processor
def _rtc_enabled(self) -> bool:
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]:
"""Preprocess images for the model.
@@ -1137,6 +1201,10 @@ class PI0Policy(PreTrainedPolicy):
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations."""
assert not self._rtc_enabled(), (
"RTC is not supported for select_action, use it with predict_action_chunk"
)
self.eval()
# Action queue logic for n_action_steps > 1
@@ -1148,7 +1216,7 @@ class PI0Policy(PreTrainedPolicy):
return self._action_queue.popleft()
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
self.eval()
@@ -1157,8 +1225,8 @@ class PI0Policy(PreTrainedPolicy):
lang_tokens, lang_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
state = self.prepare_state(batch)
# Sample actions using the model
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state)
# Sample actions using the model (pass through RTC kwargs)
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, **kwargs)
# Unpad actions to actual action dimension
original_action_dim = self.config.output_features[ACTION].shape[0]
@@ -20,6 +20,7 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.policies.rtc.configuration_rtc import RTCConfig
@PreTrainedConfig.register_subclass("pi05")
@@ -46,6 +47,9 @@ class PI05Config(PreTrainedConfig):
min_period: float = 4e-3
max_period: float = 4.0
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None
image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
# Add empty images. Used to add empty cameras when no image features are present.
+83 -14
View File
@@ -19,11 +19,12 @@ import logging
import math
from collections import deque
from pathlib import Path
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING, Literal, TypedDict
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from typing_extensions import Unpack
from lerobot.utils.import_utils import _transformers_available
@@ -42,6 +43,7 @@ else:
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.utils.constants import (
ACTION,
OBS_LANGUAGE_ATTENTION_MASK,
@@ -50,6 +52,12 @@ from lerobot.utils.constants import (
)
class ActionSelectKwargs(TypedDict, total=False):
inference_delay: int | None
prev_chunk_left_over: Tensor | None
execution_horizon: int | None
def get_safe_dtype(target_dtype, device_type):
"""Get a safe dtype for the given device type."""
if device_type == "mps" and target_dtype == torch.float64:
@@ -502,9 +510,10 @@ class PaliGemmaWithExpertModel(
class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
"""Core PI05 PyTorch model."""
def __init__(self, config: PI05Config):
def __init__(self, config: PI05Config, rtc_processor: RTCProcessor | None = None):
super().__init__()
self.config = config
self.rtc_processor = rtc_processor
paligemma_config = get_gemma_config(config.paligemma_variant)
action_expert_config = get_gemma_config(config.action_expert_variant)
@@ -556,6 +565,9 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
logging.info("Disabled gradient checkpointing for PI05Pytorch model")
def _rtc_enabled(self):
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def _apply_checkpoint(self, func, *args, **kwargs):
"""Helper method to apply gradient checkpointing if enabled."""
if self.gradient_checkpointing_enabled and self.training:
@@ -731,7 +743,16 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
return F.mse_loss(u_t, v_t, reduction="none")
@torch.no_grad() # see openpi `sample_actions` (slightly adapted)
def sample_actions(self, images, img_masks, tokens, masks, noise=None, num_steps=None) -> Tensor:
def sample_actions(
self,
images,
img_masks,
tokens,
masks,
noise=None,
num_steps=None,
**kwargs: Unpack[ActionSelectKwargs],
) -> Tensor:
"""Do a full inference forward and compute the action."""
if num_steps is None:
num_steps = self.config.num_inference_steps
@@ -770,13 +791,40 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
time = torch.tensor(1.0, dtype=torch.float32, device=device)
while time >= -dt / 2:
expanded_time = time.expand(bsize)
v_t = self.denoise_step(
prefix_pad_masks,
past_key_values,
x_t,
expanded_time,
)
x_t = x_t + dt * v_t
# Define a closure function to properly capture expanded_time
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
return self.denoise_step(
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
x_t=input_x_t,
timestep=current_timestep,
)
if self._rtc_enabled():
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
v_t = self.rtc_processor.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk_left_over,
inference_delay=inference_delay,
time=time,
original_denoise_step_partial=denoise_step_partial_call,
execution_horizon=execution_horizon,
)
else:
v_t = denoise_step_partial_call(x_t)
# Euler step
x_t += dt * v_t
# Record x_t and v_t after Euler step
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
time += dt
return x_t
@@ -839,7 +887,8 @@ class PI05Policy(PreTrainedPolicy):
self.config = config
# Initialize the core PI05 model
self.model = PI05Pytorch(config)
self.init_rtc_processor()
self.model = PI05Pytorch(config, rtc_processor=self.rtc_processor)
# Enable gradient checkpointing if requested
if config.gradient_checkpointing:
@@ -1035,6 +1084,22 @@ class PI05Policy(PreTrainedPolicy):
ACTION: deque(maxlen=self.config.n_action_steps),
}
def init_rtc_processor(self):
"""Initialize RTC processor if RTC is enabled in config."""
self.rtc_processor = None
# Create processor if config provided
# If RTC is not enabled - we can still track the denoising data
if self.config.rtc_config is not None:
self.rtc_processor = RTCProcessor(self.config.rtc_config)
model_value = getattr(self, "model", None)
if model_value is not None:
model_value.rtc_processor = self.rtc_processor
def _rtc_enabled(self) -> bool:
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]:
"""Preprocess images for the model.
@@ -1109,6 +1174,10 @@ class PI05Policy(PreTrainedPolicy):
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations."""
assert not self._rtc_enabled(), (
"RTC is not supported for select_action, use it with predict_action_chunk"
)
self.eval()
# Action queue logic for n_action_steps > 1
@@ -1120,7 +1189,7 @@ class PI05Policy(PreTrainedPolicy):
return self._action_queue.popleft()
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
self.eval()
@@ -1128,8 +1197,8 @@ class PI05Policy(PreTrainedPolicy):
images, img_masks = self._preprocess_images(batch)
tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
# Sample actions using the model (no separate state needed for PI05)
actions = self.model.sample_actions(images, img_masks, tokens, masks)
# Sample actions using the model (pass through RTC kwargs, no separate state needed for PI05)
actions = self.model.sample_actions(images, img_masks, tokens, masks, **kwargs)
# Unpad actions to actual action dimension
original_action_dim = self.config.output_features[ACTION].shape[0]
+49
View File
@@ -0,0 +1,49 @@
# Real-Time Chunking (RTC) Module
This module implements Real-Time Chunking and related adaptive inference techniques for robotics policies in LeRobot.
## Overview
Real-Time Chunking (RTC) addresses the challenge of real-time inference in action chunking policies by treating chunk generation as an inpainting problem. It strategically handles overlapping timesteps between action chunks using prefix attention mechanisms.
It is particularly effective for handling long-horizon inference in robotics policies.
## Integration with Policies
RTC can be integrated with any policy that supports flow mathicng for chunking:
- **SmolVLA**: Vision-language-action model with RTC support
- **Pi0**: Action prediction model with adaptive chunking
- **Pi05**: Action prediction model with adaptive chunking
## Original Implementation
This implementation is based on Physical Intelligence's Kinetix RTC:
- [Original RTC implementation](https://github.com/Physical-Intelligence/real-time-chunking-kinetix/blob/main/src/model.py#L214)
- [Kinetix GitHub Repository](https://github.com/Physical-Intelligence/real-time-chunking-kinetix)
## References
- [Real Time Chunking Paper](https://www.physicalintelligence.company/research/real_time_chunking)
- [Physical Intelligence Kinetix](https://github.com/Physical-Intelligence/real-time-chunking-kinetix)
## How to run
### Check with data from the dataset
```bash
uv run python examples/rtc/eval_dataset.py \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \
--rtc.execution_horizon=8 \
--device=mps \
--seed=42
```
This script will evaluate RTC on a data from a dataset and save the results to a file, u can check the results in the `rtc_debug_output` directory.
The example output should look like this:
![Flow Matching with RTC](./flow_matching.png)
It shows how flow matching works with RTC and without it. The chart shows values of action predictions for each timestep. The colour shows the the generation progress. The blue ones - earlier timesteps, the yellow ones - later timesteps. The red line is the ground truth (previous action chunk).
+219
View File
@@ -0,0 +1,219 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Action queue management for Real-Time Chunking (RTC).
This module provides ActionQueue, a thread-safe queue for managing action chunks
in real-time control scenarios. It supports both RTC-enabled and non-RTC modes,
handling action merging and leftover tracking.
"""
import logging
from threading import Lock
import torch
from torch import Tensor
from lerobot.policies.rtc.configuration_rtc import RTCConfig
logger = logging.getLogger(__name__)
class ActionQueue:
"""Thread-safe queue for managing action chunks in real-time control.
This queue handles two types of action sequences:
- Original actions: Used for RTC to compute leftovers from previous chunks
- Processed actions: Post-processed actions ready for robot execution
The queue operates in two modes:
1. RTC-enabled: Replaces the entire queue with new actions, accounting for inference delay
2. RTC-disabled: Appends new actions to the queue, maintaining continuity
Args:
cfg (RTCConfig): Configuration for Real-Time Chunking behavior.
Attributes:
queue (Tensor | None): Processed actions for robot rollout (time_steps, action_dim).
original_queue (Tensor | None): Original actions for RTC computation (time_steps, action_dim).
last_index (int): Current consumption index in the queue.
"""
def __init__(self, cfg: RTCConfig):
"""Initialize the action queue.
Args:
cfg: RTC configuration controlling queue behavior.
"""
self.queue = None # Processed actions for robot rollout
self.original_queue = None # Original actions for RTC
self.lock = Lock()
self.last_index = 0
self.cfg = cfg
def get(self) -> Tensor | None:
"""Get the next action from the queue.
Returns:
Tensor | None: The next action (action_dim,) or None if queue is empty.
Returns a clone to prevent external modifications.
"""
with self.lock:
if self.queue is None or self.last_index >= len(self.queue):
return None
action = self.queue[self.last_index]
self.last_index += 1
return action.clone()
def qsize(self) -> int:
"""Get the number of remaining actions in the queue.
Returns:
int: Number of unconsumed actions.
"""
if self.queue is None:
return 0
length = len(self.queue)
return length - self.last_index
def empty(self) -> bool:
"""Check if the queue is empty.
Returns:
bool: True if no actions remain, False otherwise.
"""
if self.queue is None:
return True
length = len(self.queue)
return length - self.last_index <= 0
def get_action_index(self) -> int:
"""Get the current action consumption index.
Returns:
int: Index of the next action to be consumed.
"""
return self.last_index
def get_left_over(self) -> Tensor | None:
"""Get leftover original actions for RTC prev_chunk_left_over.
These are the unconsumed actions from the current chunk, which will be
used by RTC to compute corrections for the next chunk.
Returns:
Tensor | None: Remaining original actions (remaining_steps, action_dim),
or None if no original queue exists.
"""
with self.lock:
if self.original_queue is None:
return None
return self.original_queue[self.last_index :]
def merge(
self,
original_actions: Tensor,
processed_actions: Tensor,
real_delay: int,
action_index_before_inference: int | None = 0,
):
"""Merge new actions into the queue.
This method operates differently based on RTC mode:
- RTC enabled: Replaces the queue, accounting for inference delay
- RTC disabled: Appends to the queue, maintaining continuity
Args:
original_actions: Unprocessed actions from policy (time_steps, action_dim).
processed_actions: Post-processed actions for robot (time_steps, action_dim).
real_delay: Number of time steps of inference delay.
action_index_before_inference: Index before inference started, for validation.
"""
with self.lock:
self._check_delays(real_delay, action_index_before_inference)
if self.cfg.enabled:
self._replace_actions_queue(original_actions, processed_actions, real_delay)
return
self._append_actions_queue(original_actions, processed_actions)
def _replace_actions_queue(self, original_actions: Tensor, processed_actions: Tensor, real_delay: int):
"""Replace the queue with new actions (RTC mode).
Discards the first `real_delay` actions since they correspond to the time
spent during inference, when the robot was executing previous actions.
Args:
original_actions: Unprocessed actions from policy.
processed_actions: Post-processed actions for robot.
real_delay: Number of time steps to skip due to inference delay.
"""
self.original_queue = original_actions[real_delay:].clone()
self.queue = processed_actions[real_delay:].clone()
logger.debug(f"original_actions shape: {self.original_queue.shape}")
logger.debug(f"processed_actions shape: {self.queue.shape}")
logger.debug(f"real_delay: {real_delay}")
self.last_index = 0
def _append_actions_queue(self, original_actions: Tensor, processed_actions: Tensor):
"""Append new actions to the queue (non-RTC mode).
Removes already-consumed actions and appends new ones, maintaining
queue continuity without replacement.
Args:
original_actions: Unprocessed actions from policy.
processed_actions: Post-processed actions for robot.
"""
if self.queue is None:
self.original_queue = original_actions.clone()
self.queue = processed_actions.clone()
return
self.original_queue = torch.cat([self.original_queue, original_actions.clone()])
self.original_queue = self.original_queue[self.last_index :]
self.queue = torch.cat([self.queue, processed_actions.clone()])
self.queue = self.queue[self.last_index :]
self.last_index = 0
def _check_delays(self, real_delay: int, action_index_before_inference: int | None = None):
"""Validate that computed delays match expectations.
Compares the delay computed from inference latency with the actual
number of actions consumed during inference.
Args:
real_delay: Delay computed from inference latency.
action_index_before_inference: Action index when inference started.
"""
if action_index_before_inference is None:
return
indexes_diff = self.last_index - action_index_before_inference
if indexes_diff != real_delay:
# Let's check that action index difference (real delay calculated based on action queue)
# is the same as delay calculated based on inference latency
logger.warning(
f"[ACTION_QUEUE] Indexes diff is not equal to real delay. "
f"Indexes diff: {indexes_diff}, real delay: {real_delay}"
)
@@ -0,0 +1,55 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Real Time Chunking (RTC) and Bidirectional Decoding (BID) configuration classes.
Based on:
- Real Time Chunking: https://www.physicalintelligence.company/research/real_time_chunking
"""
from dataclasses import dataclass
from lerobot.configs.types import RTCAttentionSchedule
@dataclass
class RTCConfig:
"""Configuration for Real Time Chunking (RTC) inference.
RTC improves real-time inference by treating chunk generation as an inpainting problem,
strategically handling overlapping timesteps between action chunks using prefix attention.
"""
# Infrastructure
enabled: bool = False
# Core RTC settings
# Todo change to exp
prefix_attention_schedule: RTCAttentionSchedule = RTCAttentionSchedule.LINEAR
max_guidance_weight: float = 10.0
execution_horizon: int = 10
# Debug settings
debug: bool = False
debug_maxlen: int = 100
def __post_init__(self):
"""Validate RTC configuration parameters."""
if self.max_guidance_weight <= 0:
raise ValueError(f"max_guidance_weight must be positive, got {self.max_guidance_weight}")
if self.debug_maxlen <= 0:
raise ValueError(f"debug_maxlen must be positive, got {self.debug_maxlen}")
+233
View File
@@ -0,0 +1,233 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Debug information handler for Real-Time Chunking (RTC)."""
from dataclasses import dataclass, field
from typing import Any
import torch
from torch import Tensor
@dataclass
class DebugStep:
"""Container for debug information from a single denoising step.
Attributes:
step_idx (int): Step index/counter.
x_t (Tensor | None): Current latent/state tensor.
v_t (Tensor | None): Velocity from denoiser.
x1_t (Tensor | None): Denoised prediction (x_t - time * v_t).
correction (Tensor | None): Correction gradient tensor.
err (Tensor | None): Weighted error term.
weights (Tensor | None): Prefix attention weights.
guidance_weight (float | Tensor | None): Applied guidance weight.
time (float | Tensor | None): Time parameter.
inference_delay (int | None): Inference delay parameter.
execution_horizon (int | None): Execution horizon parameter.
metadata (dict[str, Any]): Additional metadata.
"""
step_idx: int = 0
x_t: Tensor | None = None
v_t: Tensor | None = None
x1_t: Tensor | None = None
correction: Tensor | None = None
err: Tensor | None = None
weights: Tensor | None = None
guidance_weight: float | Tensor | None = None
time: float | Tensor | None = None
inference_delay: int | None = None
execution_horizon: int | None = None
metadata: dict[str, Any] = field(default_factory=dict)
def to_dict(self, include_tensors: bool = False) -> dict[str, Any]:
"""Convert debug step to dictionary.
Args:
include_tensors (bool): If True, include tensor values. If False, only include
tensor statistics (shape, mean, std, min, max).
Returns:
Dictionary representation of the debug step.
"""
result = {
"step_idx": self.step_idx,
"guidance_weight": (
self.guidance_weight.item()
if isinstance(self.guidance_weight, Tensor)
else self.guidance_weight
),
"time": self.time.item() if isinstance(self.time, Tensor) else self.time,
"inference_delay": self.inference_delay,
"execution_horizon": self.execution_horizon,
"metadata": self.metadata.copy(),
}
# Add tensor information
tensor_fields = ["x_t", "v_t", "x1_t", "correction", "err", "weights"]
for field_name in tensor_fields:
tensor = getattr(self, field_name)
if tensor is not None:
if include_tensors:
result[field_name] = tensor.detach().cpu()
else:
result[f"{field_name}_stats"] = {
"shape": tuple(tensor.shape),
"mean": tensor.mean().item(),
"std": tensor.std().item(),
"min": tensor.min().item(),
"max": tensor.max().item(),
}
return result
class Tracker:
"""Collects and manages debug information for RTC processing.
This tracker stores debug information from recent denoising steps in a dictionary,
using time as the key for efficient lookups and updates.
Args:
enabled (bool): Whether debug collection is enabled.
maxlen (int | None): Optional sliding window size. If provided, only the
most recent ``maxlen`` debug steps are kept. If ``None``, keeps all.
"""
def __init__(self, enabled: bool = False, maxlen: int = 100):
self.enabled = enabled
self._steps = {} if enabled else None # Dictionary with time as key
self._maxlen = maxlen
self._step_counter = 0
def reset(self) -> None:
"""Clear all recorded debug information."""
if self.enabled and self._steps is not None:
self._steps.clear()
self._step_counter = 0
@torch._dynamo.disable
def track(
self,
time: float | Tensor,
x_t: Tensor | None = None,
v_t: Tensor | None = None,
x1_t: Tensor | None = None,
correction: Tensor | None = None,
err: Tensor | None = None,
weights: Tensor | None = None,
guidance_weight: float | Tensor | None = None,
inference_delay: int | None = None,
execution_horizon: int | None = None,
**metadata,
) -> None:
"""Track debug information for a denoising step at a given time.
If a step with the given time already exists, it will be updated with the new data.
Otherwise, a new step will be created. Only non-None fields are updated/set.
Note: This method is excluded from torch.compile to avoid graph breaks from
operations like .item() which are incompatible with compiled graphs.
Args:
time (float | Tensor): Time parameter - used as the key to identify the step.
x_t (Tensor | None): Current latent/state tensor.
v_t (Tensor | None): Velocity from denoiser.
x1_t (Tensor | None): Denoised prediction.
correction (Tensor | None): Correction gradient tensor.
err (Tensor | None): Weighted error term.
weights (Tensor | None): Prefix attention weights.
guidance_weight (float | Tensor | None): Applied guidance weight.
inference_delay (int | None): Inference delay parameter.
execution_horizon (int | None): Execution horizon parameter.
**metadata: Additional metadata to store.
"""
if not self.enabled:
return
# Convert time to float and round to avoid float precision issues
time_value = time.item() if isinstance(time, Tensor) else time
time_key = round(time_value, 6) # Use rounded time as dictionary key
# Check if step with this time already exists
if time_key in self._steps:
# Update existing step with non-None fields
existing_step = self._steps[time_key]
if x_t is not None:
existing_step.x_t = x_t.detach().clone()
if v_t is not None:
existing_step.v_t = v_t.detach().clone()
if x1_t is not None:
existing_step.x1_t = x1_t.detach().clone()
if correction is not None:
existing_step.correction = correction.detach().clone()
if err is not None:
existing_step.err = err.detach().clone()
if weights is not None:
existing_step.weights = weights.detach().clone()
if guidance_weight is not None:
existing_step.guidance_weight = guidance_weight
if inference_delay is not None:
existing_step.inference_delay = inference_delay
if execution_horizon is not None:
existing_step.execution_horizon = execution_horizon
if metadata:
existing_step.metadata.update(metadata)
else:
# Create new step
step = DebugStep(
step_idx=self._step_counter,
x_t=x_t.detach().clone() if x_t is not None else None,
v_t=v_t.detach().clone() if v_t is not None else None,
x1_t=x1_t.detach().clone() if x1_t is not None else None,
correction=correction.detach().clone() if correction is not None else None,
err=err.detach().clone() if err is not None else None,
weights=weights.detach().clone() if weights is not None else None,
guidance_weight=guidance_weight,
time=time_value,
inference_delay=inference_delay,
execution_horizon=execution_horizon,
metadata=metadata,
)
# Add to dictionary
self._steps[time_key] = step
self._step_counter += 1
# Enforce maxlen if set
if self._maxlen is not None and len(self._steps) > self._maxlen:
# Remove oldest entry (first key in dict - Python 3.7+ preserves insertion order)
oldest_key = next(iter(self._steps))
del self._steps[oldest_key]
def get_all_steps(self) -> list[DebugStep]:
"""Get all recorded debug steps.
Returns:
List of all DebugStep objects (may be empty if disabled).
"""
if not self.enabled or self._steps is None:
return []
return list(self._steps.values())
def __len__(self) -> int:
"""Return the number of recorded debug steps."""
if not self.enabled or self._steps is None:
return 0
return len(self._steps)
@@ -0,0 +1,117 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Visualization utilities for RTC debug information."""
import torch
class RTCDebugVisualizer:
"""Visualizer for RTC debug information.
This class provides methods to visualize debug information collected by the Tracker,
including corrections, errors, weights, and guidance weights over denoising steps.
"""
@staticmethod
def plot_waypoints(
axes,
tensor,
start_from: int = 0,
color: str = "blue",
label: str = "",
alpha: float = 0.7,
linewidth: float = 2,
marker: str | None = None,
markersize: int = 4,
):
"""Plot trajectories across multiple dimensions.
This function plots a tensor's values across time for multiple dimensions,
with each dimension plotted on a separate axis.
Args:
axes: Array of matplotlib axes (one for each dimension).
tensor: The tensor to plot (can be torch.Tensor or numpy array).
Shape should be (time_steps, num_dims) or (batch, time_steps, num_dims).
start_from: Starting index for the x-axis.
color: Color for the plot lines.
label: Label for the plot legend.
alpha: Transparency level for the plot.
linewidth: Width of the plot lines.
marker: Marker style for data points (e.g., 'o', 's', '^').
markersize: Size of the markers.
"""
import numpy as np
# Handle None tensor
if tensor is None:
return
# Convert tensor to numpy if needed
tensor_np = tensor.detach().cpu().numpy() if isinstance(tensor, torch.Tensor) else tensor
# Handle different tensor shapes
if tensor_np.ndim == 3:
# If batch dimension present, take first batch
tensor_np = tensor_np[0]
elif tensor_np.ndim == 1:
# If 1D, reshape to (time_steps, 1)
tensor_np = tensor_np.reshape(-1, 1)
# Get dimensions
time_steps, num_dims = tensor_np.shape
# Create x-axis indices
x_indices = np.arange(start_from, start_from + time_steps)
# Plot each dimension on its corresponding axis
num_axes = len(axes) if hasattr(axes, "__len__") else 1
for dim_idx in range(min(num_dims, num_axes)):
ax = axes[dim_idx] if hasattr(axes, "__len__") else axes
# Plot the trajectory
if marker:
ax.plot(
x_indices,
tensor_np[:, dim_idx],
color=color,
label=label if dim_idx == 0 else "", # Only show label once
alpha=alpha,
linewidth=linewidth,
marker=marker,
markersize=markersize,
)
else:
ax.plot(
x_indices,
tensor_np[:, dim_idx],
color=color,
label=label if dim_idx == 0 else "", # Only show label once
alpha=alpha,
linewidth=linewidth,
)
# Add grid and labels if not already present
if not ax.xaxis.get_label().get_text():
ax.set_xlabel("Step", fontsize=10)
if not ax.yaxis.get_label().get_text():
ax.set_ylabel(f"Dim {dim_idx}", fontsize=10)
ax.grid(True, alpha=0.3)
# Add legend if label provided and this is the first dimension
if label and dim_idx == 0:
ax.legend(loc="best", fontsize=8)
Binary file not shown.

After

Width:  |  Height:  |  Size: 1.3 MiB

@@ -0,0 +1,72 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Latency tracking utilities for Real-Time Chunking (RTC)."""
from collections import deque
import numpy as np
class LatencyTracker:
"""Tracks recent latencies and provides max/percentile queries.
Args:
maxlen (int | None): Optional sliding window size. If provided, only the
most recent ``maxlen`` latencies are kept. If ``None``, keeps all.
"""
def __init__(self, maxlen: int = 100):
self._values = deque(maxlen=maxlen)
self.reset()
def reset(self) -> None:
"""Clear all recorded latencies."""
self._values.clear()
self.max_latency = 0.0
def add(self, latency: float) -> None:
"""Add a latency sample (seconds)."""
# Ensure numeric and non-negative
val = float(latency)
if val < 0:
return
self._values.append(val)
self.max_latency = max(self.max_latency, val)
def __len__(self) -> int:
return len(self._values)
def max(self) -> float | None:
"""Return the maximum latency or None if empty."""
return self.max_latency
def percentile(self, q: float) -> float | None:
"""Return the q-quantile (q in [0,1]) of recorded latencies or None if empty."""
if not self._values:
return 0.0
q = float(q)
if q <= 0.0:
return min(self._values)
if q >= 1.0:
return self.max_latency
vals = np.array(list(self._values), dtype=np.float32)
return float(np.quantile(vals, q))
def p95(self) -> float | None:
"""Return the 95th percentile latency or None if empty."""
return self.percentile(0.95)
+297
View File
@@ -0,0 +1,297 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Real-Time Chunking (RTC) implementation for LeRobot.
Based on Physical Intelligence's Kinetix implementation:
https://github.com/Physical-Intelligence/real-time-chunking-kinetix/blob/main/src/model.py#L214
"""
import logging
import math
import torch
from torch import Tensor
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.debug_tracker import Tracker
logger = logging.getLogger(__name__)
class RTCProcessor:
"""Real-Time Chunking processor for action chunking policies.
This class implements RTC techniques including velocity calculation,
prefix attention, and adaptive chunk processing.
"""
def __init__(self, rtc_config: RTCConfig):
self.rtc_config = rtc_config
self.tracker = None
if rtc_config.debug:
self.tracker = Tracker(
enabled=rtc_config.debug,
maxlen=rtc_config.debug_maxlen,
)
# ====================== Tracker Proxy Methods ======================
def track(
self,
time: float | Tensor,
x_t: Tensor | None = None,
v_t: Tensor | None = None,
x1_t: Tensor | None = None,
correction: Tensor | None = None,
err: Tensor | None = None,
weights: Tensor | None = None,
guidance_weight: float | Tensor | None = None,
inference_delay: int | None = None,
execution_horizon: int | None = None,
**metadata,
) -> None:
"""Proxy method to track debug information.
If tracker is None or disabled, this method does nothing.
Otherwise, it forwards the call to tracker.track().
"""
if self.tracker is not None:
self.tracker.track(
time=time,
x_t=x_t,
v_t=v_t,
x1_t=x1_t,
correction=correction,
err=err,
weights=weights,
guidance_weight=guidance_weight,
inference_delay=inference_delay,
execution_horizon=execution_horizon,
**metadata,
)
def get_all_debug_steps(self) -> list:
"""Get all debug steps from tracker.
Returns empty list if tracker is disabled or None.
"""
if self.tracker is not None:
return self.tracker.get_all_steps()
return []
def is_debug_enabled(self) -> bool:
"""Check if debug tracking is enabled.
Returns True if tracker exists and is enabled.
"""
return self.tracker is not None and self.tracker.enabled
def reset_tracker(self) -> None:
"""Reset the tracker, clearing all recorded steps.
Does nothing if tracker is None.
"""
if self.tracker is not None:
self.tracker.reset()
# ====================== End Tracker Proxy Methods ======================
def denoise_step(
self,
x_t,
prev_chunk_left_over,
inference_delay,
time,
original_denoise_step_partial,
execution_horizon=None,
) -> Tensor:
"""RTC guidance wrapper around an existing denoiser.
This method wraps an original denoising callable that only takes ``x_t`` and
returns a base denoised velocity ``v_t``. It then applies Real-Time Chunking
(RTC) prefix guidance using the leftover prefix from the previous chunk.
Args:
x_t (Tensor): Current latent/state to denoise. Shape ``(B, T, A)`` or ``(T, A)``.
prev_chunk_left_over (Tensor | None): Unexecuted prefix from the previous
chunk. Shape ``(B, T_prev, A)`` or ``(T_prev, A)``. If ``None``, no guidance
is applied and the method returns ``v_t`` from the original denoiser.
inference_delay (int): Number of timesteps from the prefix to use for guidance.
time (float | Tensor): Scalar in [0, 1] indicating normalized time. Must be
broadcastable with ``x_t``.
original_denoise_step_partial (Callable[[Tensor], Tensor]): Callable that
computes the base denoised velocity given only ``x_t``.
execution_horizon (int | None): Horizon used to build prefix weights. If
``None``, defaults to ``self.rtc_config.execution_horizon``.
Returns:
Tensor: Guided velocity with the same shape as ``v_t``.
Notes:
- If inputs are 2D, a batch dimension is temporarily added and removed at the end.
- If ``prev_chunk_left_over`` is shorter than the current chunk length ``T``, it is
right-padded with zeros to match ``T``.
- Prefix weights are constructed via ``get_prefix_weights(inference_delay, execution_horizon, T)``
and broadcast to ``(B, T, A)``.
- Guidance correction is computed via autograd using ``x1_t = x_t + time * v_t`` and
``error = (prev_chunk_left_over - x1_t) * weights``.
- The final guidance weight is clamped by ``max_guidance_weight`` from the config.
Reference:
https://www.physicalintelligence.company/download/real_time_chunking.pdf
"""
# In the original implementation, the time goes from 0 to 1 and
# In our implementation, the time goes from 1 to 0
# So we need to invert the time
tau = 1 - time
if prev_chunk_left_over is None:
# First step, no guidance - return v_t
v_t = original_denoise_step_partial(x_t)
return v_t
x_t = x_t.clone().detach()
squeezed = False
if len(x_t.shape) < 3:
# Add batch dimension
x_t = x_t.unsqueeze(0)
squeezed = True
if len(prev_chunk_left_over.shape) < 3:
# Add batch dimension
prev_chunk_left_over = prev_chunk_left_over.unsqueeze(0)
if execution_horizon is None:
execution_horizon = self.rtc_config.execution_horizon
# If the previous action chunk is to short then it doesn't make sense to use long execution horizon
# because there is nothing to merge
if execution_horizon > prev_chunk_left_over.shape[1]:
execution_horizon = prev_chunk_left_over.shape[1]
batch_size = x_t.shape[0]
action_chunk_size = x_t.shape[1]
action_dim = x_t.shape[2]
if prev_chunk_left_over.shape[1] < action_chunk_size or prev_chunk_left_over.shape[2] < action_dim:
padded = torch.zeros(batch_size, action_chunk_size, action_dim).to(x_t.device)
padded[:, : prev_chunk_left_over.shape[1], : prev_chunk_left_over.shape[2]] = prev_chunk_left_over
prev_chunk_left_over = padded
assert prev_chunk_left_over.shape == x_t.shape, (
"The padded previous chunk must be the same size as the input tensor"
)
weights = (
self.get_prefix_weights(inference_delay, execution_horizon, action_chunk_size)
.to(x_t.device)
.unsqueeze(0)
.unsqueeze(-1)
)
with torch.enable_grad():
v_t = original_denoise_step_partial(x_t)
x_t.requires_grad_(True)
x1_t = x_t - time * v_t # noqa: N806
err = (prev_chunk_left_over - x1_t) * weights
grad_outputs = err.clone().detach()
correction = torch.autograd.grad(x1_t, x_t, grad_outputs, retain_graph=False)[0]
max_guidance_weight = torch.as_tensor(self.rtc_config.max_guidance_weight)
tau_tensor = torch.as_tensor(tau)
squared_one_minus_tau = (1 - tau_tensor) ** 2
inv_r2 = (squared_one_minus_tau + tau_tensor**2) / (squared_one_minus_tau)
c = torch.nan_to_num((1 - tau_tensor) / tau_tensor, posinf=max_guidance_weight)
guidance_weight = torch.nan_to_num(c * inv_r2, posinf=max_guidance_weight)
guidance_weight = torch.minimum(guidance_weight, max_guidance_weight)
result = v_t - guidance_weight * correction
# Remove the batch dimension if it was added
if squeezed:
result = result.squeeze(0)
correction = correction.squeeze(0)
x1_t = x1_t.squeeze(0)
err = err.squeeze(0)
self.track(
time=time,
x1_t=x1_t,
correction=correction,
err=err,
weights=weights,
guidance_weight=guidance_weight,
inference_delay=inference_delay,
execution_horizon=execution_horizon,
)
return result
def get_prefix_weights(self, start, end, total):
start = min(start, end)
if self.rtc_config.prefix_attention_schedule == RTCAttentionSchedule.ZEROS:
weights = torch.zeros(total)
weights[:start] = 1.0
elif self.rtc_config.prefix_attention_schedule == RTCAttentionSchedule.ONES:
weights = torch.ones(total)
weights[end:] = 0.0
elif self.rtc_config.prefix_attention_schedule == RTCAttentionSchedule.LINEAR:
lin_weights = self._linweights(start, end, total)
weights = self._add_trailing_zeros(lin_weights, total, end)
weights = self._add_leading_ones(weights, start, total)
elif self.rtc_config.prefix_attention_schedule == RTCAttentionSchedule.EXP:
lin_weights = self._linweights(start, end, total)
lin_weights = lin_weights * torch.expm1(lin_weights).div(math.e - 1)
weights = self._add_trailing_zeros(lin_weights, total, end)
weights = self._add_leading_ones(weights, start, total)
return weights
def _linweights(self, start, end, total):
skip_steps_at_end = max(total - end, 0)
linspace_steps = total - skip_steps_at_end - start
if end <= start or linspace_steps <= 0:
return torch.tensor([])
return torch.linspace(1, 0, linspace_steps + 2)[1:-1]
def _add_trailing_zeros(self, weights, total, end):
zeros_len = total - end
if zeros_len <= 0:
return weights
zeros = torch.zeros(zeros_len)
return torch.cat([weights, zeros])
def _add_leading_ones(self, weights, start, total):
ones_len = min(start, total)
if ones_len <= 0:
return weights
ones = torch.ones(ones_len)
return torch.cat([ones, weights])
@@ -20,6 +20,7 @@ from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import (
CosineDecayWithWarmupSchedulerConfig,
)
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.utils.constants import OBS_IMAGES
@@ -102,6 +103,9 @@ class SmolVLAConfig(PreTrainedConfig):
min_period: float = 4e-3 # sensitivity range for the timestep used in sine-cosine positional encoding
max_period: float = 4.0
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None
def __post_init__(self):
super().__post_init__()
+101 -19
View File
@@ -54,12 +54,15 @@ policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base")
import math
from collections import deque
from typing import TypedDict
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from typing_extensions import Unpack
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel
from lerobot.policies.utils import (
@@ -69,6 +72,12 @@ from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LAN
from lerobot.utils.utils import get_safe_dtype
class ActionSelectKwargs(TypedDict, total=False):
inference_delay: int | None
prev_chunk_left_over: Tensor | None
execution_horizon: int | None
def create_sinusoidal_pos_embedding(
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
) -> Tensor:
@@ -232,8 +241,8 @@ class SmolVLAPolicy(PreTrainedPolicy):
super().__init__(config)
config.validate_features()
self.config = config
self.model = VLAFlowMatching(config)
self.init_rtc_processor()
self.model = VLAFlowMatching(config, rtc_processor=self.rtc_processor)
self.reset()
def reset(self):
@@ -242,10 +251,28 @@ class SmolVLAPolicy(PreTrainedPolicy):
ACTION: deque(maxlen=self.config.n_action_steps),
}
def init_rtc_processor(self):
"""Initialize RTC processor if RTC is enabled in config."""
self.rtc_processor = None
# Lets create processor if the config provided
# If RTC is not enabled - we still can track the denoising data
if self.config.rtc_config is not None:
self.rtc_processor = RTCProcessor(self.config.rtc_config)
# In case of calling init_rtc_processor after the model is created
# We need to set the rtc_processor to the model
# During the normal initialization process the model is not created yet
model_value = getattr(self, "model", None)
if model_value is not None:
model_value.rtc_processor = self.rtc_processor
def get_optim_params(self) -> dict:
return self.parameters()
def _get_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
def _get_action_chunk(
self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs: Unpack[ActionSelectKwargs]
) -> Tensor:
# TODO: Check if this for loop is needed.
# Context: In fact, self.queues contains only ACTION field, and in inference, we don't have action in the batch
# In the case of offline inference, we have the action in the batch
@@ -260,7 +287,9 @@ class SmolVLAPolicy(PreTrainedPolicy):
lang_tokens = batch[f"{OBS_LANGUAGE_TOKENS}"]
lang_masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, noise=noise)
actions = self.model.sample_actions(
images, img_masks, lang_tokens, lang_masks, state, noise=noise, **kwargs
)
# Unpad actions
original_action_dim = self.config.action_feature.shape[0]
@@ -278,30 +307,37 @@ class SmolVLAPolicy(PreTrainedPolicy):
return batch
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
def predict_action_chunk(
self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs: Unpack[ActionSelectKwargs]
) -> Tensor:
self.eval()
batch = self._prepare_batch(batch)
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
actions = self._get_action_chunk(batch, noise)
actions = self._get_action_chunk(batch, noise, **kwargs)
return actions
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
def select_action(
self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs: Unpack[ActionSelectKwargs]
) -> Tensor:
"""Select a single action given environment observations.
This method wraps `select_actions` in order to return one action at a time for execution in the
environment. It works by managing the actions in a queue and only calling `select_actions` when the
queue is empty.
"""
assert not self._rtc_enabled(), (
"RTC is not supported for select_action, use it with predict_action_chunk"
)
self.eval()
batch = self._prepare_batch(batch)
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
# querying the policy.
if len(self._queues[ACTION]) == 0:
if self._check_get_actions_condition():
actions = self._get_action_chunk(batch, noise)
# `self.predict_action_chunk` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
@@ -310,6 +346,12 @@ class SmolVLAPolicy(PreTrainedPolicy):
return self._queues[ACTION].popleft()
def _check_get_actions_condition(self) -> bool:
return len(self._queues[ACTION]) == 0
def _rtc_enabled(self) -> bool:
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]:
"""Do a full training forward pass to compute the loss"""
if self.config.adapt_to_pi_aloha:
@@ -471,7 +513,7 @@ class VLAFlowMatching(nn.Module):
└──────────────────────────────┘
"""
def __init__(self, config: SmolVLAConfig):
def __init__(self, config: SmolVLAConfig, rtc_processor: RTCProcessor | None = None):
super().__init__()
self.config = config
@@ -485,7 +527,6 @@ class VLAFlowMatching(nn.Module):
num_vlm_layers=self.config.num_vlm_layers,
self_attn_every_n_layers=self.config.self_attn_every_n_layers,
expert_width_multiplier=self.config.expert_width_multiplier,
device=self.config.device,
)
self.state_proj = nn.Linear(
self.config.max_state_dim, self.vlm_with_expert.config.text_config.hidden_size
@@ -510,6 +551,10 @@ class VLAFlowMatching(nn.Module):
self.add_image_special_tokens = self.config.add_image_special_tokens
self.image_end_token = torch.tensor([self.fake_image_token], dtype=torch.long)
self.prefix_length = self.config.prefix_length
self.rtc_processor = rtc_processor
def _rtc_enabled(self):
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def set_requires_grad(self):
for params in self.state_proj.parameters():
@@ -706,7 +751,16 @@ class VLAFlowMatching(nn.Module):
losses = F.mse_loss(u_t, v_t, reduction="none")
return losses
def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor:
def sample_actions(
self,
images,
img_masks,
lang_tokens,
lang_masks,
state,
noise=None,
**kwargs: Unpack[ActionSelectKwargs],
) -> Tensor:
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
bsize = state.shape[0]
device = state.device
@@ -734,17 +788,45 @@ class VLAFlowMatching(nn.Module):
x_t = noise
time = torch.tensor(1.0, dtype=torch.float32, device=device)
while time >= -dt / 2:
expanded_time = time.expand(bsize)
v_t = self.denoise_step(
prefix_pad_masks,
past_key_values,
x_t,
expanded_time,
)
# Define a closure function to properly capture expanded_time
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
return self.denoise_step(
x_t=input_x_t,
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
timestep=current_timestep,
)
if self._rtc_enabled():
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
v_t = self.rtc_processor.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk_left_over,
inference_delay=inference_delay,
time=time,
original_denoise_step_partial=denoise_step_partial_call,
execution_horizon=execution_horizon,
)
else:
v_t = denoise_step_partial_call(x_t)
# Euler step
x_t += dt * v_t
# Record x_t and v_t after Euler step (other params are recorded in rtc_processor.denoise_step)
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
time += dt
return x_t
def denoise_step(
+206
View File
@@ -0,0 +1,206 @@
"""
Profiling utilities for performance analysis.
Usage:
from lerobot.utils.profiling import profile_method, get_profiling_stats, print_profiling_summary
@profile_method
def my_slow_function(x):
return x * 2
# At end of execution:
print_profiling_summary()
"""
import functools
import logging
import time
from collections import defaultdict
from threading import Lock
from typing import Any, Callable
logger = logging.getLogger(__name__)
# Global profiling statistics storage
_profiling_stats: dict[str, list[float]] = defaultdict(list)
_profiling_lock = Lock()
_profiling_enabled = False
def enable_profiling():
"""Enable profiling globally."""
global _profiling_enabled
_profiling_enabled = True
logger.info("Profiling enabled")
def disable_profiling():
"""Disable profiling globally."""
global _profiling_enabled
_profiling_enabled = False
logger.info("Profiling disabled")
def is_profiling_enabled() -> bool:
"""Check if profiling is enabled."""
return _profiling_enabled
def record_timing(name: str, duration: float):
"""Record a timing measurement.
Args:
name: Name/identifier for this timing
duration: Duration in seconds
"""
if not _profiling_enabled:
return
with _profiling_lock:
_profiling_stats[name].append(duration)
def profile_method(func: Callable) -> Callable:
"""Decorator to profile a method or function.
Args:
func: Function to profile
Returns:
Wrapped function that records execution time
"""
@functools.wraps(func)
def wrapper(*args, **kwargs) -> Any:
if not _profiling_enabled:
return func(*args, **kwargs)
start = time.perf_counter()
try:
result = func(*args, **kwargs)
return result
finally:
duration = time.perf_counter() - start
# Use fully qualified name
name = f"{func.__module__}.{func.__qualname__}"
record_timing(name, duration)
return wrapper
class ProfileContext:
"""Context manager for profiling code blocks.
Usage:
with ProfileContext("my_operation"):
# ... code to profile ...
"""
def __init__(self, name: str):
self.name = name
self.start = None
def __enter__(self):
if _profiling_enabled:
self.start = time.perf_counter()
return self
def __exit__(self, *args):
if _profiling_enabled and self.start is not None:
duration = time.perf_counter() - self.start
record_timing(self.name, duration)
def get_profiling_stats() -> dict[str, dict[str, float]]:
"""Get summary statistics for all profiled functions.
Returns:
Dictionary mapping function names to their stats (count, mean, min, max, total)
"""
with _profiling_lock:
summary = {}
for name, times in _profiling_stats.items():
if times:
summary[name] = {
"count": len(times),
"mean": sum(times) / len(times),
"min": min(times),
"max": max(times),
"total": sum(times),
"mean_ms": (sum(times) / len(times)) * 1000,
"min_ms": min(times) * 1000,
"max_ms": max(times) * 1000,
}
return summary
def clear_profiling_stats():
"""Clear all profiling statistics."""
with _profiling_lock:
_profiling_stats.clear()
logger.info("Profiling stats cleared")
def print_profiling_summary(sort_by: str = "total"):
"""Print formatted summary of profiling statistics.
Args:
sort_by: Sort key ('total', 'mean', 'count', 'max')
"""
summary = get_profiling_stats()
if not summary:
logger.info("No profiling data available")
return
logger.info("\n" + "=" * 100)
logger.info("PROFILING SUMMARY")
logger.info("=" * 100)
# Sort by requested key
sorted_items = sorted(summary.items(), key=lambda x: x[1].get(sort_by, 0), reverse=True)
# Print header
logger.info(
f"{'Function':<60} {'Count':>8} {'Mean (ms)':>12} {'Min (ms)':>12} {'Max (ms)':>12} {'Total (s)':>12}"
)
logger.info("-" * 100)
# Print each function's stats
for name, stats in sorted_items:
# Shorten long names
display_name = name if len(name) <= 60 else "..." + name[-57:]
logger.info(
f"{display_name:<60} "
f"{stats['count']:>8} "
f"{stats['mean_ms']:>12.2f} "
f"{stats['min_ms']:>12.2f} "
f"{stats['max_ms']:>12.2f} "
f"{stats['total']:>12.2f}"
)
logger.info("=" * 100)
# Print summary
total_time = sum(s["total"] for s in summary.values())
total_calls = sum(s["count"] for s in summary.values())
logger.info(f"\nTotal profiled time: {total_time:.2f}s across {total_calls} calls")
logger.info("=" * 100 + "\n")
def profile_section(name: str):
"""Return a context manager for profiling a code section.
Args:
name: Name for this section
Returns:
ProfileContext instance
Usage:
with profile_section("data_loading"):
data = load_data()
"""
return ProfileContext(name)
+336
View File
@@ -0,0 +1,336 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test PI0.5 policy with Real-Time Chunking (RTC) enabled during inference."""
import os
import pytest
import torch
# Skip this entire module in CI
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="This test requires local OpenPI installation and is not meant for CI",
)
from lerobot.configs.types import FeatureType, PolicyFeature, RTCAttentionSchedule # noqa: E402
from lerobot.policies.pi05 import PI05Config, PI05Policy, make_pi05_pre_post_processors # noqa: E402
from lerobot.policies.rtc.configuration_rtc import RTCConfig # noqa: E402
from lerobot.utils.random_utils import set_seed # noqa: E402
from tests.utils import require_cuda # noqa: E402
@require_cuda
def test_pi05_rtc_initialization():
"""Test PI0.5 policy can initialize RTC processor."""
set_seed(42)
config = PI05Config(max_action_dim=7, max_state_dim=14, dtype="float32")
# Add RTC config
config.rtc_config = RTCConfig(
enabled=True,
execution_horizon=10,
max_guidance_weight=5.0,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
debug=False,
)
config.input_features = {
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
}
# Instantiate policy
policy = PI05Policy(config)
# Verify RTC processor is initialized
assert hasattr(policy, "rtc_processor")
assert policy.rtc_processor is not None
assert policy.rtc_processor.rtc_config.enabled is True
print("✓ PI0.5 RTC initialization: Test passed")
@require_cuda
def test_pi05_rtc_initialization_without_rtc_config():
"""Test PI0.5 policy can initialize without RTC config."""
set_seed(42)
config = PI05Config(max_action_dim=7, max_state_dim=14, dtype="float32")
# Instantiate policy
policy = PI05Policy(config)
# Verify RTC processor is not initialized
assert hasattr(policy, "rtc_processor")
assert policy.rtc_processor is None
assert policy.model.rtc_processor is None
assert policy._rtc_enabled() is False
print("✓ PI0.5 RTC initialization without RTC config: Test passed")
@require_cuda
def test_pi05_rtc_inference_with_prev_chunk():
"""Test PI0.5 policy inference with RTC and previous chunk."""
set_seed(42)
config = PI05Config(max_action_dim=7, max_state_dim=14, chunk_size=50, dtype="float32")
# Add RTC config
config.rtc_config = RTCConfig(
enabled=True,
execution_horizon=10,
max_guidance_weight=5.0,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
debug=False,
)
config.input_features = {
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
}
# Create dataset stats (PI0.5 uses QUANTILES normalization)
dataset_stats = {
"observation.state": {
"mean": torch.zeros(14),
"std": torch.ones(14),
"q01": -torch.ones(14),
"q99": torch.ones(14),
},
"action": {
"mean": torch.zeros(7),
"std": torch.ones(7),
"q01": -torch.ones(7),
"q99": torch.ones(7),
},
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
}
# Instantiate policy and preprocessor
policy = PI05Policy(config)
policy.eval()
preprocessor, _ = make_pi05_pre_post_processors(config=config, dataset_stats=dataset_stats)
device = config.device
# Create dummy batch
batch = {
"observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
"observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
"task": ["Pick up the object"],
}
batch = preprocessor(batch)
# Create previous chunk
prev_chunk = torch.randn(1, 25, 7, dtype=torch.float32, device=device)
with torch.no_grad():
# Use same noise for fair comparison
noise = policy.model.sample_noise((1, config.chunk_size, 7), device)
# Test with RTC and previous chunk
actions_with_rtc = policy.predict_action_chunk(
batch,
noise=noise.clone(),
prev_chunk_left_over=prev_chunk,
inference_delay=4,
execution_horizon=10,
)
# Test without RTC for comparison
policy.config.rtc_config.enabled = False
actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone())
policy.config.rtc_config.enabled = True
# Verify shapes
assert actions_with_rtc.shape == (1, config.chunk_size, 7)
assert actions_without_rtc.shape == (1, config.chunk_size, 7)
# With previous chunk, actions should be different (RTC guidance applied)
assert not torch.allclose(actions_with_rtc, actions_without_rtc, rtol=1e-3)
print("✓ PI0.5 RTC inference with prev_chunk: Test passed")
@require_cuda
def test_pi05_rtc_inference_without_prev_chunk():
"""Test PI0.5 policy inference with RTC but no previous chunk (RTC should have no effect)."""
set_seed(42)
config = PI05Config(max_action_dim=7, max_state_dim=14, chunk_size=50, dtype="float32")
# Add RTC config
config.rtc_config = RTCConfig(
enabled=True,
execution_horizon=10,
max_guidance_weight=5.0,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
debug=False,
)
config.input_features = {
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
}
# Create dataset stats (PI0.5 uses QUANTILES normalization)
dataset_stats = {
"observation.state": {
"mean": torch.zeros(14),
"std": torch.ones(14),
"q01": -torch.ones(14),
"q99": torch.ones(14),
},
"action": {
"mean": torch.zeros(7),
"std": torch.ones(7),
"q01": -torch.ones(7),
"q99": torch.ones(7),
},
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
}
# Instantiate policy and preprocessor
policy = PI05Policy(config)
policy.eval()
preprocessor, _ = make_pi05_pre_post_processors(config=config, dataset_stats=dataset_stats)
device = config.device
# Create dummy batch
batch = {
"observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
"observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
"task": ["Pick up the object"],
}
batch = preprocessor(batch)
with torch.no_grad():
# Use same noise for fair comparison
noise = policy.model.sample_noise((1, config.chunk_size, 7), device)
# Test with RTC enabled but no previous chunk
actions_with_rtc_no_prev = policy.predict_action_chunk(
batch,
noise=noise.clone(),
prev_chunk_left_over=None,
)
# Test without RTC
policy.config.rtc_config.enabled = False
actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone())
policy.config.rtc_config.enabled = True
# Without previous chunk, RTC should have no effect
assert torch.allclose(actions_with_rtc_no_prev, actions_without_rtc, rtol=1e-5)
print("✓ PI0.5 RTC inference without prev_chunk: Test passed")
@require_cuda
def test_pi05_rtc_validation_rules():
"""Test PI0.5 policy with RTC follows all three validation rules."""
set_seed(42)
config = PI05Config(max_action_dim=7, max_state_dim=14, chunk_size=50, dtype="float32")
# Add RTC config
config.rtc_config = RTCConfig(
enabled=True,
execution_horizon=10,
max_guidance_weight=5.0,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
debug=False,
)
config.input_features = {
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
}
# Create dataset stats (PI0.5 uses QUANTILES normalization)
dataset_stats = {
"observation.state": {
"mean": torch.zeros(14),
"std": torch.ones(14),
"q01": -torch.ones(14),
"q99": torch.ones(14),
},
"action": {
"mean": torch.zeros(7),
"std": torch.ones(7),
"q01": -torch.ones(7),
"q99": torch.ones(7),
},
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
}
# Instantiate policy and preprocessor
policy = PI05Policy(config)
policy.eval()
preprocessor, _ = make_pi05_pre_post_processors(config=config, dataset_stats=dataset_stats)
device = config.device
# Create dummy batch
batch = {
"observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
"observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
"task": ["Pick up the object"],
}
batch = preprocessor(batch)
# Create previous chunk
prev_chunk = torch.randn(1, 25, 7, dtype=torch.float32, device=device)
inference_delay = 4
execution_horizon = 10
with torch.no_grad():
# Use same noise for fair comparison
noise = policy.model.sample_noise((1, config.chunk_size, 7), device)
# Test with RTC
actions_with_rtc = policy.predict_action_chunk(
batch,
noise=noise.clone(),
prev_chunk_left_over=prev_chunk,
inference_delay=inference_delay,
execution_horizon=execution_horizon,
)
# Test without RTC
policy.config.rtc_config.enabled = False
actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone())
policy.config.rtc_config.enabled = True
assert not torch.allclose(actions_with_rtc, actions_without_rtc, rtol=1e-3)
+378
View File
@@ -0,0 +1,378 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test PI0 policy with Real-Time Chunking (RTC) enabled during inference."""
import os
import pytest
import torch
# Skip this entire module in CI
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="This test requires local OpenPI installation and is not meant for CI",
)
from lerobot.configs.types import FeatureType, PolicyFeature, RTCAttentionSchedule # noqa: E402
from lerobot.policies.pi0 import PI0Config, PI0Policy, make_pi0_pre_post_processors # noqa: E402
from lerobot.policies.rtc.configuration_rtc import RTCConfig # noqa: E402
from lerobot.utils.random_utils import set_seed # noqa: E402
from tests.utils import require_cuda # noqa: E402
@require_cuda
def test_pi0_rtc_initialization():
"""Test PI0 policy can initialize RTC processor."""
set_seed(42)
config = PI0Config(max_action_dim=7, max_state_dim=14, dtype="float32")
# Add RTC config
config.rtc_config = RTCConfig(
enabled=True,
execution_horizon=10,
max_guidance_weight=5.0,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
debug=False,
)
config.input_features = {
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
}
# Instantiate policy
policy = PI0Policy(config)
# Verify RTC processor is initialized
assert hasattr(policy, "rtc_processor")
assert policy.rtc_processor is not None
assert policy.rtc_processor.rtc_config.enabled is True
print("✓ PI0 RTC initialization: Test passed")
@require_cuda
def test_pi0_rtc_initialization_without_rtc_config():
"""Test PI0 policy can initialize without RTC config."""
set_seed(42)
config = PI0Config(max_action_dim=7, max_state_dim=14, dtype="float32")
# Instantiate policy
policy = PI0Policy(config)
# Verify RTC processor is not initialized
assert hasattr(policy, "rtc_processor")
assert policy.rtc_processor is None
assert policy.model.rtc_processor is None
assert policy._rtc_enabled() is False
print("✓ PI0 RTC initialization without RTC config: Test passed")
def test_pi0_rtc_inference_with_prev_chunk():
"""Test PI0 policy inference with RTC and previous chunk."""
set_seed(42)
config = PI0Config(max_action_dim=7, max_state_dim=14, chunk_size=50, dtype="float32")
# Add RTC config
config.rtc_config = RTCConfig(
enabled=True,
execution_horizon=10,
max_guidance_weight=5.0,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
debug=False,
)
config.input_features = {
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
}
# Create dataset stats
dataset_stats = {
"observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
"action": {"mean": torch.zeros(7), "std": torch.ones(7)},
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
}
# Instantiate policy and preprocessor
policy = PI0Policy(config)
policy.eval()
preprocessor, _ = make_pi0_pre_post_processors(config=config, dataset_stats=dataset_stats)
device = config.device
# Create dummy batch
batch = {
"observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
"observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
"task": ["Pick up the object"],
}
batch = preprocessor(batch)
# Create previous chunk
prev_chunk = torch.randn(1, 25, 7, dtype=torch.float32, device=device)
with torch.no_grad():
# Use same noise for fair comparison
noise = policy.model.sample_noise((1, config.chunk_size, 7), device)
# Test with RTC and previous chunk
actions_with_rtc = policy.predict_action_chunk(
batch,
noise=noise.clone(),
prev_chunk_left_over=prev_chunk,
inference_delay=4,
execution_horizon=10,
)
# Test without RTC for comparison
policy.config.rtc_config.enabled = False
actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone())
policy.config.rtc_config.enabled = True
# Verify shapes
assert actions_with_rtc.shape == (1, config.chunk_size, 7)
assert actions_without_rtc.shape == (1, config.chunk_size, 7)
# With previous chunk, actions should be different (RTC guidance applied)
assert not torch.allclose(actions_with_rtc, actions_without_rtc, rtol=1e-3)
print("✓ PI0 RTC inference with prev_chunk: Test passed")
@require_cuda
def test_pi0_rtc_inference_without_prev_chunk():
"""Test PI0 policy inference with RTC but no previous chunk (RTC should have no effect)."""
set_seed(42)
config = PI0Config(max_action_dim=7, max_state_dim=14, chunk_size=50, dtype="float32")
# Add RTC config
config.rtc_config = RTCConfig(
enabled=True,
execution_horizon=10,
max_guidance_weight=5.0,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
debug=False,
)
config.input_features = {
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
}
# Create dataset stats
dataset_stats = {
"observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
"action": {"mean": torch.zeros(7), "std": torch.ones(7)},
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
}
# Instantiate policy and preprocessor
policy = PI0Policy(config)
policy.eval()
preprocessor, _ = make_pi0_pre_post_processors(config=config, dataset_stats=dataset_stats)
device = config.device
# Create dummy batch
batch = {
"observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
"observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
"task": ["Pick up the object"],
}
batch = preprocessor(batch)
with torch.no_grad():
# Use same noise for fair comparison
noise = policy.model.sample_noise((1, config.chunk_size, 7), device)
# Test with RTC enabled but no previous chunk
actions_with_rtc_no_prev = policy.predict_action_chunk(
batch,
noise=noise.clone(),
prev_chunk_left_over=None,
)
# Test without RTC
policy.config.rtc_config.enabled = False
actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone())
policy.config.rtc_config.enabled = True
# Without previous chunk, RTC should have no effect
assert torch.allclose(actions_with_rtc_no_prev, actions_without_rtc, rtol=1e-5)
print("✓ PI0 RTC inference without prev_chunk: Test passed")
@require_cuda
def test_pi0_rtc_validation_rules():
"""Test PI0 policy with RTC follows all three validation rules."""
set_seed(42)
config = PI0Config(max_action_dim=7, max_state_dim=14, chunk_size=50, dtype="float32")
# Add RTC config
config.rtc_config = RTCConfig(
enabled=True,
execution_horizon=10,
max_guidance_weight=5.0,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
debug=False,
)
config.input_features = {
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
}
# Create dataset stats
dataset_stats = {
"observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
"action": {"mean": torch.zeros(7), "std": torch.ones(7)},
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
}
# Instantiate policy and preprocessor
policy = PI0Policy(config)
policy.eval()
preprocessor, _ = make_pi0_pre_post_processors(config=config, dataset_stats=dataset_stats)
device = config.device
# Create dummy batch
batch = {
"observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
"observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
"task": ["Pick up the object"],
}
batch = preprocessor(batch)
# Create previous chunk
prev_chunk = torch.randn(1, 25, 7, dtype=torch.float32, device=device)
inference_delay = 4
execution_horizon = 10
with torch.no_grad():
# Use same noise for fair comparison
noise = policy.model.sample_noise((1, config.chunk_size, 7), device)
# Test with RTC
actions_with_rtc = policy.predict_action_chunk(
batch,
noise=noise.clone(),
prev_chunk_left_over=prev_chunk,
inference_delay=inference_delay,
execution_horizon=execution_horizon,
)
# Test without RTC
policy.config.rtc_config.enabled = False
actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone())
policy.config.rtc_config.enabled = True
assert not torch.allclose(actions_with_rtc, actions_without_rtc, rtol=1e-3)
"""Test PI0 with different RTC attention schedules."""
set_seed(42)
schedules = [
RTCAttentionSchedule.ZEROS,
RTCAttentionSchedule.ONES,
RTCAttentionSchedule.LINEAR,
RTCAttentionSchedule.EXP,
]
config = PI0Config(max_action_dim=7, max_state_dim=14, chunk_size=50, dtype="float32")
config.input_features = {
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
}
# Create dataset stats
dataset_stats = {
"observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
"action": {"mean": torch.zeros(7), "std": torch.ones(7)},
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
}
device = config.device
for schedule in schedules:
print(f"Testing schedule: {schedule}")
# Add RTC config with specific schedule
config.rtc_config = RTCConfig(
enabled=True,
execution_horizon=10,
max_guidance_weight=5.0,
prefix_attention_schedule=schedule,
debug=False,
)
# Instantiate policy
policy = PI0Policy(config)
policy.eval()
preprocessor, _ = make_pi0_pre_post_processors(config=config, dataset_stats=dataset_stats)
# Create dummy batch
batch = {
"observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
"observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
"task": ["Pick up the object"],
}
batch = preprocessor(batch)
# Create previous chunk
prev_chunk = torch.randn(1, 25, 7, dtype=torch.float32, device=device)
with torch.no_grad():
noise = policy.model.sample_noise((1, config.chunk_size, 7), device)
actions = policy.predict_action_chunk(
batch,
noise=noise,
prev_chunk_left_over=prev_chunk,
inference_delay=4,
execution_horizon=10,
)
# Verify shape
assert actions.shape == (1, config.chunk_size, 7)
print(f" ✓ Schedule {schedule}: Test passed")
print("✓ PI0 RTC different schedules: All schedules tested")
+825
View File
@@ -0,0 +1,825 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for RTC ActionQueue module."""
import threading
import time
import pytest
import torch
from lerobot.policies.rtc.action_queue import ActionQueue
from lerobot.policies.rtc.configuration_rtc import RTCConfig
# ====================== Fixtures ======================
@pytest.fixture
def rtc_config_enabled():
"""Create an RTC config with RTC enabled."""
return RTCConfig(enabled=True, execution_horizon=10, max_guidance_weight=1.0)
@pytest.fixture
def rtc_config_disabled():
"""Create an RTC config with RTC disabled."""
return RTCConfig(enabled=False, execution_horizon=10, max_guidance_weight=1.0)
@pytest.fixture
def sample_actions():
"""Create sample action tensors for testing."""
return {
"original": torch.randn(50, 6), # (time_steps, action_dim)
"processed": torch.randn(50, 6),
"short": torch.randn(10, 6),
"longer": torch.randn(100, 6),
}
@pytest.fixture
def action_queue_rtc_enabled(rtc_config_enabled):
"""Create an ActionQueue with RTC enabled."""
return ActionQueue(rtc_config_enabled)
@pytest.fixture
def action_queue_rtc_disabled(rtc_config_disabled):
"""Create an ActionQueue with RTC disabled."""
return ActionQueue(rtc_config_disabled)
# ====================== Initialization Tests ======================
def test_action_queue_initialization_rtc_enabled(rtc_config_enabled):
"""Test ActionQueue initializes correctly with RTC enabled."""
queue = ActionQueue(rtc_config_enabled)
assert queue.queue is None
assert queue.original_queue is None
assert queue.last_index == 0
assert queue.cfg.enabled is True
def test_action_queue_initialization_rtc_disabled(rtc_config_disabled):
"""Test ActionQueue initializes correctly with RTC disabled."""
queue = ActionQueue(rtc_config_disabled)
assert queue.queue is None
assert queue.original_queue is None
assert queue.last_index == 0
assert queue.cfg.enabled is False
# ====================== get() Tests ======================
def test_get_returns_none_when_empty(action_queue_rtc_enabled):
"""Test get() returns None when queue is empty."""
action = action_queue_rtc_enabled.get()
assert action is None
def test_get_returns_actions_sequentially(action_queue_rtc_enabled, sample_actions):
"""Test get() returns actions in sequence."""
# Initialize queue with actions
action_queue_rtc_enabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=0)
# Get first action
action1 = action_queue_rtc_enabled.get()
assert action1 is not None
assert action1.shape == (6,)
assert torch.equal(action1, sample_actions["processed"][0])
# Get second action
action2 = action_queue_rtc_enabled.get()
assert action2 is not None
assert torch.equal(action2, sample_actions["processed"][1])
def test_get_returns_none_after_exhaustion(action_queue_rtc_enabled, sample_actions):
"""Test get() returns None after all actions are consumed."""
# Use short action sequence
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
# Consume all actions
for _ in range(10):
action = action_queue_rtc_enabled.get()
assert action is not None
# Next get should return None
action = action_queue_rtc_enabled.get()
assert action is None
def test_get_increments_last_index(action_queue_rtc_enabled, sample_actions):
"""Test get() increments last_index correctly."""
action_queue_rtc_enabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=0)
assert action_queue_rtc_enabled.last_index == 0
action_queue_rtc_enabled.get()
assert action_queue_rtc_enabled.last_index == 1
action_queue_rtc_enabled.get()
assert action_queue_rtc_enabled.last_index == 2
# ====================== qsize() Tests ======================
def test_qsize_returns_zero_when_empty(action_queue_rtc_enabled):
"""Test qsize() returns 0 when queue is empty."""
assert action_queue_rtc_enabled.qsize() == 0
def test_qsize_returns_correct_size(action_queue_rtc_enabled, sample_actions):
"""Test qsize() returns correct number of remaining actions."""
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
assert action_queue_rtc_enabled.qsize() == 10
action_queue_rtc_enabled.get()
assert action_queue_rtc_enabled.qsize() == 9
action_queue_rtc_enabled.get()
assert action_queue_rtc_enabled.qsize() == 8
def test_qsize_after_exhaustion(action_queue_rtc_enabled, sample_actions):
"""Test qsize() returns 0 after queue is exhausted."""
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
# Consume all actions
for _ in range(10):
action_queue_rtc_enabled.get()
assert action_queue_rtc_enabled.qsize() == 0
# ====================== empty() Tests ======================
def test_empty_returns_true_when_empty(action_queue_rtc_enabled):
"""Test empty() returns True when queue is empty."""
assert action_queue_rtc_enabled.empty() is True
def test_empty_returns_false_when_not_empty(action_queue_rtc_enabled, sample_actions):
"""Test empty() returns False when queue has actions."""
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
assert action_queue_rtc_enabled.empty() is False
def test_empty_after_partial_consumption(action_queue_rtc_enabled, sample_actions):
"""Test empty() returns False after partial consumption."""
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
action_queue_rtc_enabled.get()
action_queue_rtc_enabled.get()
assert action_queue_rtc_enabled.empty() is False
def test_empty_after_full_consumption(action_queue_rtc_enabled, sample_actions):
"""Test empty() returns True after all actions consumed."""
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
# Consume all
for _ in range(10):
action_queue_rtc_enabled.get()
assert action_queue_rtc_enabled.empty() is True
# ====================== get_action_index() Tests ======================
def test_get_action_index_initial_value(action_queue_rtc_enabled):
"""Test get_action_index() returns 0 initially."""
assert action_queue_rtc_enabled.get_action_index() == 0
def test_get_action_index_after_consumption(action_queue_rtc_enabled, sample_actions):
"""Test get_action_index() tracks consumption correctly."""
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
assert action_queue_rtc_enabled.get_action_index() == 0
action_queue_rtc_enabled.get()
assert action_queue_rtc_enabled.get_action_index() == 1
action_queue_rtc_enabled.get()
action_queue_rtc_enabled.get()
assert action_queue_rtc_enabled.get_action_index() == 3
# ====================== get_left_over() Tests ======================
def test_get_left_over_returns_none_when_empty(action_queue_rtc_enabled):
"""Test get_left_over() returns None when queue is empty."""
leftover = action_queue_rtc_enabled.get_left_over()
assert leftover is None
def test_get_left_over_returns_all_when_unconsumed(action_queue_rtc_enabled, sample_actions):
"""Test get_left_over() returns all original actions when none consumed."""
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
leftover = action_queue_rtc_enabled.get_left_over()
assert leftover is not None
assert leftover.shape == (10, 6)
assert torch.equal(leftover, sample_actions["short"])
def test_get_left_over_returns_remaining_after_consumption(action_queue_rtc_enabled, sample_actions):
"""Test get_left_over() returns only remaining original actions."""
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
# Consume 3 actions
action_queue_rtc_enabled.get()
action_queue_rtc_enabled.get()
action_queue_rtc_enabled.get()
leftover = action_queue_rtc_enabled.get_left_over()
assert leftover is not None
assert leftover.shape == (7, 6)
assert torch.equal(leftover, sample_actions["short"][3:])
def test_get_left_over_returns_empty_after_exhaustion(action_queue_rtc_enabled, sample_actions):
"""Test get_left_over() returns empty tensor after all consumed."""
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
# Consume all
for _ in range(10):
action_queue_rtc_enabled.get()
leftover = action_queue_rtc_enabled.get_left_over()
assert leftover is not None
assert leftover.shape == (0, 6)
# ====================== merge() with RTC Enabled Tests ======================
def test_merge_replaces_queue_when_rtc_enabled(action_queue_rtc_enabled, sample_actions):
"""Test merge() replaces queue when RTC is enabled."""
# Add initial actions
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
assert action_queue_rtc_enabled.qsize() == 10
# Consume some actions
action_queue_rtc_enabled.get()
action_queue_rtc_enabled.get()
assert action_queue_rtc_enabled.qsize() == 8
# Merge new actions - should replace, not append
action_queue_rtc_enabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=5)
# Queue should be replaced with new actions minus delay
# Original has 50 actions, delay is 5, so remaining is 45
assert action_queue_rtc_enabled.qsize() == 45
assert action_queue_rtc_enabled.get_action_index() == 0
def test_merge_respects_real_delay(action_queue_rtc_enabled, sample_actions):
"""Test merge() correctly applies real_delay when RTC is enabled."""
delay = 10
action_queue_rtc_enabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=delay)
# Queue should have original length minus delay
expected_size = len(sample_actions["original"]) - delay
assert action_queue_rtc_enabled.qsize() == expected_size
# First action should be the one at index [delay]
first_action = action_queue_rtc_enabled.get()
assert torch.equal(first_action, sample_actions["processed"][delay])
def test_merge_resets_last_index_when_rtc_enabled(action_queue_rtc_enabled, sample_actions):
"""Test merge() resets last_index to 0 when RTC is enabled."""
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
action_queue_rtc_enabled.get()
action_queue_rtc_enabled.get()
assert action_queue_rtc_enabled.last_index == 2
# Merge new actions
action_queue_rtc_enabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=5)
assert action_queue_rtc_enabled.last_index == 0
def test_merge_with_zero_delay(action_queue_rtc_enabled, sample_actions):
"""Test merge() with zero delay keeps all actions."""
action_queue_rtc_enabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=0)
assert action_queue_rtc_enabled.qsize() == len(sample_actions["original"])
def test_merge_with_large_delay(action_queue_rtc_enabled, sample_actions):
"""Test merge() with delay larger than action sequence."""
# Delay is larger than sequence length
delay = 100
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=delay)
# Queue should be empty (delay >= length)
assert action_queue_rtc_enabled.qsize() == 0
# ====================== merge() with RTC Disabled Tests ======================
def test_merge_appends_when_rtc_disabled(action_queue_rtc_disabled, sample_actions):
"""Test merge() appends actions when RTC is disabled."""
# Add initial actions
action_queue_rtc_disabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
initial_size = action_queue_rtc_disabled.qsize()
assert initial_size == 10
# Merge more actions
action_queue_rtc_disabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
# Should have appended
assert action_queue_rtc_disabled.qsize() == initial_size + 10
def test_merge_removes_consumed_actions_when_appending(action_queue_rtc_disabled, sample_actions):
"""Test merge() removes consumed actions before appending when RTC is disabled."""
# Add initial actions
action_queue_rtc_disabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
assert action_queue_rtc_disabled.qsize() == 10
# Consume 3 actions
action_queue_rtc_disabled.get()
action_queue_rtc_disabled.get()
action_queue_rtc_disabled.get()
assert action_queue_rtc_disabled.qsize() == 7
# Merge more actions
action_queue_rtc_disabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
# Should have 7 remaining + 10 new = 17
assert action_queue_rtc_disabled.qsize() == 17
def test_merge_resets_last_index_after_append(action_queue_rtc_disabled, sample_actions):
"""Test merge() resets last_index after appending when RTC is disabled."""
action_queue_rtc_disabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
action_queue_rtc_disabled.get()
action_queue_rtc_disabled.get()
assert action_queue_rtc_disabled.last_index == 2
# Merge more actions
action_queue_rtc_disabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
# last_index should be reset to 0
assert action_queue_rtc_disabled.last_index == 0
def test_merge_ignores_delay_when_rtc_disabled(action_queue_rtc_disabled, sample_actions):
"""Test merge() ignores real_delay parameter when RTC is disabled."""
action_queue_rtc_disabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=10)
# All actions should be in queue (delay ignored)
assert action_queue_rtc_disabled.qsize() == len(sample_actions["original"])
def test_merge_first_call_with_rtc_disabled(action_queue_rtc_disabled, sample_actions):
"""Test merge() on first call with RTC disabled."""
action_queue_rtc_disabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=0)
assert action_queue_rtc_disabled.qsize() == len(sample_actions["original"])
assert action_queue_rtc_disabled.last_index == 0
# ====================== merge() with Different Action Shapes Tests ======================
def test_merge_with_different_action_dims():
"""Test merge() handles actions with different dimensions."""
cfg = RTCConfig(enabled=True, execution_horizon=10)
queue = ActionQueue(cfg)
# Actions with 4 dimensions instead of 6
actions_4d = torch.randn(20, 4)
queue.merge(actions_4d, actions_4d, real_delay=5)
action = queue.get()
assert action.shape == (4,)
def test_merge_with_different_lengths():
"""Test merge() handles action sequences of varying lengths."""
cfg = RTCConfig(enabled=False, execution_horizon=10)
queue = ActionQueue(cfg)
# Add sequences of different lengths
queue.merge(torch.randn(10, 6), torch.randn(10, 6), real_delay=0)
assert queue.qsize() == 10
queue.merge(torch.randn(25, 6), torch.randn(25, 6), real_delay=0)
assert queue.qsize() == 35
# ====================== merge() Delay Validation Tests ======================
def test_merge_validates_delay_consistency(action_queue_rtc_enabled, sample_actions, caplog):
"""Test merge() validates that real_delay matches action index difference."""
import logging
caplog.set_level(logging.WARNING)
# Initialize queue
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
# Consume 5 actions
for _ in range(5):
action_queue_rtc_enabled.get()
# Merge with mismatched delay (should log warning)
# We consumed 5 actions, so index is 5. If we pass action_index_before_inference=0,
# then indexes_diff=5, but if real_delay=3, it will warn
action_queue_rtc_enabled.merge(
sample_actions["original"],
sample_actions["processed"],
real_delay=3,
action_index_before_inference=0,
)
# Check warning was logged
assert "Indexes diff is not equal to real delay" in caplog.text
def test_merge_no_warning_when_delays_match(action_queue_rtc_enabled, sample_actions, caplog):
"""Test merge() doesn't warn when delays are consistent."""
import logging
caplog.set_level(logging.WARNING)
# Initialize queue
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
# Consume 5 actions
for _ in range(5):
action_queue_rtc_enabled.get()
# Merge with matching delay
action_queue_rtc_enabled.merge(
sample_actions["original"],
sample_actions["processed"],
real_delay=5,
action_index_before_inference=0,
)
# Should not have warning
assert "Indexes diff is not equal to real delay" not in caplog.text
def test_merge_skips_validation_when_action_index_none(action_queue_rtc_enabled, sample_actions, caplog):
"""Test merge() skips delay validation when action_index_before_inference is None."""
import logging
caplog.set_level(logging.WARNING)
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
for _ in range(5):
action_queue_rtc_enabled.get()
# Pass None for action_index_before_inference
action_queue_rtc_enabled.merge(
sample_actions["original"],
sample_actions["processed"],
real_delay=999, # Doesn't matter
action_index_before_inference=None,
)
# Should not warn (validation skipped)
assert "Indexes diff is not equal to real delay" not in caplog.text
# ====================== Thread Safety Tests ======================
def test_get_is_thread_safe(action_queue_rtc_enabled, sample_actions):
"""Test get() is thread-safe with multiple consumers."""
action_queue_rtc_enabled.merge(sample_actions["longer"], sample_actions["longer"], real_delay=0)
results = []
errors = []
def consumer():
try:
for _ in range(25):
action = action_queue_rtc_enabled.get()
if action is not None:
results.append(action)
time.sleep(0.001)
except Exception as e:
errors.append(e)
threads = [threading.Thread(target=consumer) for _ in range(4)]
for t in threads:
t.start()
for t in threads:
t.join()
# Should not have errors
assert len(errors) == 0
# Should have consumed all actions (100 total, 4 threads * 25 each)
assert len(results) == 100
# All results should be unique (no duplicate consumption)
# We can verify by checking that indices are not duplicated
# Since we don't track indices in results, we check total count is correct
assert action_queue_rtc_enabled.qsize() == 0
def test_merge_is_thread_safe(action_queue_rtc_disabled, sample_actions):
"""Test merge() is thread-safe with multiple producers."""
errors = []
def producer():
try:
for _ in range(5):
action_queue_rtc_disabled.merge(
sample_actions["short"], sample_actions["short"], real_delay=0
)
time.sleep(0.001)
except Exception as e:
errors.append(e)
threads = [threading.Thread(target=producer) for _ in range(3)]
for t in threads:
t.start()
for t in threads:
t.join()
# Should not have errors
assert len(errors) == 0
# Should have accumulated all actions (3 threads * 5 merges * 10 actions = 150)
assert action_queue_rtc_disabled.qsize() == 150
def test_concurrent_get_and_merge(action_queue_rtc_disabled, sample_actions):
"""Test concurrent get() and merge() operations."""
errors = []
consumed_count = [0]
def consumer():
try:
for _ in range(50):
action = action_queue_rtc_disabled.get()
if action is not None:
consumed_count[0] += 1
time.sleep(0.001)
except Exception as e:
errors.append(e)
def producer():
try:
for _ in range(10):
action_queue_rtc_disabled.merge(
sample_actions["short"], sample_actions["short"], real_delay=0
)
time.sleep(0.005)
except Exception as e:
errors.append(e)
consumer_threads = [threading.Thread(target=consumer) for _ in range(2)]
producer_threads = [threading.Thread(target=producer) for _ in range(2)]
for t in consumer_threads + producer_threads:
t.start()
for t in consumer_threads + producer_threads:
t.join()
# Should not have errors
assert len(errors) == 0
# Should have consumed some or all actions (non-deterministic due to timing)
# Total produced: 2 producers * 10 merges * 10 actions = 200
# Total consumed attempts: 2 consumers * 50 = 100
assert consumed_count[0] <= 200
# ====================== get_left_over() Thread Safety Tests ======================
def test_get_left_over_is_thread_safe(action_queue_rtc_enabled, sample_actions):
"""Test get_left_over() is thread-safe with concurrent access."""
action_queue_rtc_enabled.merge(sample_actions["longer"], sample_actions["longer"], real_delay=0)
errors = []
leftovers = []
def reader():
try:
for _ in range(20):
leftover = action_queue_rtc_enabled.get_left_over()
if leftover is not None:
leftovers.append(leftover.shape[0])
time.sleep(0.001)
except Exception as e:
errors.append(e)
threads = [threading.Thread(target=reader) for _ in range(3)]
# Also consume some actions concurrently
def consumer():
try:
for _ in range(10):
action_queue_rtc_enabled.get()
time.sleep(0.002)
except Exception as e:
errors.append(e)
consumer_thread = threading.Thread(target=consumer)
all_threads = threads + [consumer_thread]
for t in all_threads:
t.start()
for t in all_threads:
t.join()
# Should not have errors
assert len(errors) == 0
# Leftovers should be monotonically decreasing or stable
# (as actions are consumed, leftover size decreases)
assert len(leftovers) > 0
# ====================== Edge Cases Tests ======================
def test_queue_with_single_action(action_queue_rtc_enabled):
"""Test queue behavior with a single action."""
single_action_original = torch.randn(1, 6)
single_action_processed = torch.randn(1, 6)
action_queue_rtc_enabled.merge(single_action_original, single_action_processed, real_delay=0)
assert action_queue_rtc_enabled.qsize() == 1
action = action_queue_rtc_enabled.get()
assert action is not None
assert action.shape == (6,)
assert action_queue_rtc_enabled.qsize() == 0
def test_queue_behavior_after_multiple_merge_cycles(action_queue_rtc_enabled, sample_actions):
"""Test queue maintains correct state through multiple merge cycles."""
for _ in range(5):
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
# Consume half
for _ in range(5):
action_queue_rtc_enabled.get()
# Merge again
action_queue_rtc_enabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=3)
assert action_queue_rtc_enabled.qsize() > 0
def test_queue_with_all_zeros_actions(action_queue_rtc_enabled):
"""Test queue handles all-zero action tensors."""
zeros_actions = torch.zeros(20, 6)
action_queue_rtc_enabled.merge(zeros_actions, zeros_actions, real_delay=0)
action = action_queue_rtc_enabled.get()
assert torch.all(action == 0)
def test_queue_clones_input_tensors(action_queue_rtc_enabled, sample_actions):
"""Test that merge() clones input tensors, not storing references."""
original_copy = sample_actions["original"].clone()
processed_copy = sample_actions["processed"].clone()
action_queue_rtc_enabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=0)
# Modify original tensors
sample_actions["original"].fill_(999.0)
sample_actions["processed"].fill_(-999.0)
# Queue should have cloned values
action = action_queue_rtc_enabled.get()
assert not torch.equal(action, sample_actions["processed"][0])
assert torch.equal(action, processed_copy[0])
leftover = action_queue_rtc_enabled.get_left_over()
assert not torch.equal(leftover, sample_actions["original"][1:])
assert torch.equal(leftover, original_copy[1:])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_queue_handles_gpu_tensors():
"""Test queue correctly handles GPU tensors."""
cfg = RTCConfig(enabled=True, execution_horizon=10)
queue = ActionQueue(cfg)
actions_gpu = torch.randn(20, 6, device="cuda")
queue.merge(actions_gpu, actions_gpu, real_delay=0)
action = queue.get()
assert action.device.type == "cuda"
leftover = queue.get_left_over()
assert leftover.device.type == "cuda"
def test_queue_handles_different_dtypes():
"""Test queue handles actions with different dtypes."""
cfg = RTCConfig(enabled=True, execution_horizon=10)
queue = ActionQueue(cfg)
# Use float64 instead of default float32
actions_f64 = torch.randn(20, 6, dtype=torch.float64)
queue.merge(actions_f64, actions_f64, real_delay=0)
action = queue.get()
assert action.dtype == torch.float64
def test_empty_with_none_queue(action_queue_rtc_enabled):
"""Test empty() correctly handles None queue."""
assert action_queue_rtc_enabled.queue is None
assert action_queue_rtc_enabled.empty() is True
def test_qsize_with_none_queue(action_queue_rtc_enabled):
"""Test qsize() correctly handles None queue."""
assert action_queue_rtc_enabled.queue is None
assert action_queue_rtc_enabled.qsize() == 0
# ====================== Integration Tests ======================
def test_typical_rtc_workflow(action_queue_rtc_enabled, sample_actions):
"""Test a typical RTC workflow: merge, consume, merge with delay."""
# First inference
action_queue_rtc_enabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=0)
initial_size = action_queue_rtc_enabled.qsize()
assert initial_size == 50
# Consume 10 actions (execution_horizon)
for _ in range(10):
action = action_queue_rtc_enabled.get()
assert action is not None
assert action_queue_rtc_enabled.qsize() == 40
# Second inference with delay
action_index_before = action_queue_rtc_enabled.get_action_index()
action_queue_rtc_enabled.merge(
sample_actions["original"],
sample_actions["processed"],
real_delay=5,
action_index_before_inference=action_index_before,
)
# Queue should be replaced, minus delay
assert action_queue_rtc_enabled.qsize() == 45
assert action_queue_rtc_enabled.get_action_index() == 0
def test_typical_non_rtc_workflow(action_queue_rtc_disabled, sample_actions):
"""Test a typical non-RTC workflow: merge, consume, merge again."""
# First inference
action_queue_rtc_disabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=0)
assert action_queue_rtc_disabled.qsize() == 50
# Consume 40 actions
for _ in range(40):
action = action_queue_rtc_disabled.get()
assert action is not None
assert action_queue_rtc_disabled.qsize() == 10
# Second inference (should append)
action_queue_rtc_disabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=0)
# Should have 10 remaining + 50 new = 60
assert action_queue_rtc_disabled.qsize() == 60
@@ -0,0 +1,65 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for RTC configuration module."""
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.policies.rtc.configuration_rtc import RTCConfig
# ====================== Initialization Tests ======================
def test_rtc_config_default_initialization():
"""Test RTCConfig initializes with default values."""
config = RTCConfig()
assert config.enabled is False
assert config.prefix_attention_schedule == RTCAttentionSchedule.LINEAR
assert config.max_guidance_weight == 10.0
assert config.execution_horizon == 10
assert config.debug is False
assert config.debug_maxlen == 100
def test_rtc_config_custom_initialization():
"""Test RTCConfig initializes with custom values."""
config = RTCConfig(
enabled=True,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
max_guidance_weight=5.0,
execution_horizon=20,
debug=True,
debug_maxlen=200,
)
assert config.enabled is True
assert config.prefix_attention_schedule == RTCAttentionSchedule.EXP
assert config.max_guidance_weight == 5.0
assert config.execution_horizon == 20
assert config.debug is True
assert config.debug_maxlen == 200
def test_rtc_config_partial_initialization():
"""Test RTCConfig with partial custom values."""
config = RTCConfig(enabled=True, max_guidance_weight=15.0)
assert config.enabled is True
assert config.max_guidance_weight == 15.0
# Other values should be defaults
assert config.prefix_attention_schedule == RTCAttentionSchedule.LINEAR
assert config.execution_horizon == 10
assert config.debug is False
+488
View File
@@ -0,0 +1,488 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for RTC debug tracker module."""
import pytest
import torch
from lerobot.policies.rtc.debug_tracker import DebugStep, Tracker
# ====================== Fixtures ======================
@pytest.fixture
def sample_tensors():
"""Create sample tensors for testing."""
return {
"x_t": torch.randn(1, 50, 6),
"v_t": torch.randn(1, 50, 6),
"x1_t": torch.randn(1, 50, 6),
"correction": torch.randn(1, 50, 6),
"err": torch.randn(1, 50, 6),
"weights": torch.randn(1, 50, 1),
}
@pytest.fixture
def enabled_tracker():
"""Create an enabled tracker with default settings."""
return Tracker(enabled=True, maxlen=100)
@pytest.fixture
def disabled_tracker():
"""Create a disabled tracker."""
return Tracker(enabled=False)
# ====================== DebugStep Tests ======================
def test_debug_step_initialization():
"""Test that DebugStep can be initialized with default values."""
step = DebugStep()
assert step.step_idx == 0
assert step.x_t is None
assert step.v_t is None
assert step.x1_t is None
assert step.correction is None
assert step.err is None
assert step.weights is None
assert step.guidance_weight is None
assert step.time is None
assert step.inference_delay is None
assert step.execution_horizon is None
assert step.metadata == {}
def test_debug_step_with_values(sample_tensors):
"""Test DebugStep initialization with actual values."""
step = DebugStep(
step_idx=5,
x_t=sample_tensors["x_t"],
v_t=sample_tensors["v_t"],
x1_t=sample_tensors["x1_t"],
correction=sample_tensors["correction"],
err=sample_tensors["err"],
weights=sample_tensors["weights"],
guidance_weight=2.5,
time=0.8,
inference_delay=4,
execution_horizon=8,
metadata={"custom_key": "custom_value"},
)
assert step.step_idx == 5
assert torch.equal(step.x_t, sample_tensors["x_t"])
assert torch.equal(step.v_t, sample_tensors["v_t"])
assert torch.equal(step.x1_t, sample_tensors["x1_t"])
assert torch.equal(step.correction, sample_tensors["correction"])
assert torch.equal(step.err, sample_tensors["err"])
assert torch.equal(step.weights, sample_tensors["weights"])
assert step.guidance_weight == 2.5
assert step.time == 0.8
assert step.inference_delay == 4
assert step.execution_horizon == 8
assert step.metadata == {"custom_key": "custom_value"}
def test_debug_step_to_dict_without_tensors(sample_tensors):
"""Test converting DebugStep to dictionary without tensor values."""
step = DebugStep(
step_idx=3,
x_t=sample_tensors["x_t"],
v_t=sample_tensors["v_t"],
guidance_weight=torch.tensor(3.0),
time=torch.tensor(0.5),
inference_delay=2,
execution_horizon=10,
)
result = step.to_dict(include_tensors=False)
assert result["step_idx"] == 3
assert result["guidance_weight"] == 3.0
assert result["time"] == 0.5
assert result["inference_delay"] == 2
assert result["execution_horizon"] == 10
# Check tensor statistics are included
assert "x_t_stats" in result
assert "v_t_stats" in result
assert "x1_t_stats" not in result # x1_t was None
# Verify statistics structure
assert "shape" in result["x_t_stats"]
assert "mean" in result["x_t_stats"]
assert "std" in result["x_t_stats"]
assert "min" in result["x_t_stats"]
assert "max" in result["x_t_stats"]
# Verify shape matches original tensor
assert result["x_t_stats"]["shape"] == tuple(sample_tensors["x_t"].shape)
def test_debug_step_to_dict_with_tensors(sample_tensors):
"""Test converting DebugStep to dictionary with tensor values."""
step = DebugStep(
step_idx=1,
x_t=sample_tensors["x_t"],
v_t=sample_tensors["v_t"],
guidance_weight=1.5,
time=0.9,
)
result = step.to_dict(include_tensors=True)
assert result["step_idx"] == 1
assert result["guidance_weight"] == 1.5
assert result["time"] == 0.9
# Check tensors are included (as CPU tensors)
assert "x_t" in result
assert "v_t" in result
assert isinstance(result["x_t"], torch.Tensor)
assert isinstance(result["v_t"], torch.Tensor)
assert result["x_t"].device.type == "cpu"
assert result["v_t"].device.type == "cpu"
def test_debug_step_to_dict_with_none_guidance_weight():
"""Test to_dict handles None guidance_weight correctly."""
step = DebugStep(step_idx=0, time=1.0, guidance_weight=None)
result = step.to_dict(include_tensors=False)
assert result["guidance_weight"] is None
def test_tracker_initialization_enabled():
"""Test tracker initialization when enabled."""
tracker = Tracker(enabled=True, maxlen=50)
assert tracker.enabled is True
assert tracker._steps == {}
assert tracker._maxlen == 50
assert tracker._step_counter == 0
assert len(tracker) == 0
def test_tracker_reset_when_enabled(enabled_tracker, sample_tensors):
"""Test reset clears all steps when tracker is enabled."""
# Add some steps
enabled_tracker.track(time=1.0, x_t=sample_tensors["x_t"])
enabled_tracker.track(time=0.9, x_t=sample_tensors["x_t"])
assert len(enabled_tracker) == 2
# Reset
enabled_tracker.reset()
assert len(enabled_tracker) == 0
assert enabled_tracker._step_counter == 0
assert enabled_tracker._steps == {}
def test_tracker_reset_when_disabled(disabled_tracker):
"""Test reset on disabled tracker doesn't cause errors."""
disabled_tracker.reset()
assert len(disabled_tracker) == 0
# ====================== Tracker.track() Tests ======================
def test_track_creates_new_step(enabled_tracker, sample_tensors):
"""Test that track creates a new step when time doesn't exist."""
enabled_tracker.track(
time=1.0,
x_t=sample_tensors["x_t"],
v_t=sample_tensors["v_t"],
guidance_weight=5.0,
inference_delay=4,
execution_horizon=8,
)
assert len(enabled_tracker) == 1
steps = enabled_tracker.get_all_steps()
assert len(steps) == 1
assert steps[0].step_idx == 0
assert steps[0].time == 1.0
assert torch.equal(steps[0].x_t, sample_tensors["x_t"])
assert torch.equal(steps[0].v_t, sample_tensors["v_t"])
assert steps[0].guidance_weight == 5.0
assert steps[0].inference_delay == 4
assert steps[0].execution_horizon == 8
def test_track_updates_existing_step(enabled_tracker, sample_tensors):
"""Test that track updates an existing step at the same time."""
# Create initial step
enabled_tracker.track(time=0.9, x_t=sample_tensors["x_t"])
assert len(enabled_tracker) == 1
steps = enabled_tracker.get_all_steps()
assert steps[0].v_t is None
# Update the same timestep with v_t
enabled_tracker.track(time=0.9, v_t=sample_tensors["v_t"])
assert len(enabled_tracker) == 1 # Still only one step
steps = enabled_tracker.get_all_steps()
assert torch.equal(steps[0].x_t, sample_tensors["x_t"]) # Original x_t preserved
assert torch.equal(steps[0].v_t, sample_tensors["v_t"]) # New v_t added
def test_track_with_tensor_time(enabled_tracker, sample_tensors):
"""Test track handles tensor time values correctly."""
time_tensor = torch.tensor(0.8)
enabled_tracker.track(time=time_tensor, x_t=sample_tensors["x_t"])
steps = enabled_tracker.get_all_steps()
assert len(steps) == 1
assert abs(steps[0].time - 0.8) < 1e-6 # Use approximate comparison for floating point
def test_track_time_rounding(enabled_tracker, sample_tensors):
"""Test that track rounds time to avoid floating point precision issues."""
# These times should be treated as the same after rounding to 6 decimals
enabled_tracker.track(time=0.9000001, x_t=sample_tensors["x_t"])
enabled_tracker.track(time=0.9000002, v_t=sample_tensors["v_t"])
# Should still be one step (times rounded to same value)
assert len(enabled_tracker) == 1
steps = enabled_tracker.get_all_steps()
assert torch.equal(steps[0].x_t, sample_tensors["x_t"])
assert torch.equal(steps[0].v_t, sample_tensors["v_t"])
def test_track_does_nothing_when_disabled(disabled_tracker, sample_tensors):
"""Test that track does nothing when tracker is disabled."""
disabled_tracker.track(time=1.0, x_t=sample_tensors["x_t"])
assert len(disabled_tracker) == 0
def test_track_with_metadata(enabled_tracker, sample_tensors):
"""Test track stores custom metadata."""
enabled_tracker.track(time=0.7, x_t=sample_tensors["x_t"], custom_field="custom_value", count=42)
steps = enabled_tracker.get_all_steps()
assert steps[0].metadata["custom_field"] == "custom_value"
assert steps[0].metadata["count"] == 42
def test_track_updates_metadata(enabled_tracker):
"""Test that track updates metadata for existing steps."""
enabled_tracker.track(time=0.6, meta1="value1")
enabled_tracker.track(time=0.6, meta2="value2")
steps = enabled_tracker.get_all_steps()
assert steps[0].metadata["meta1"] == "value1"
assert steps[0].metadata["meta2"] == "value2"
def test_track_clones_tensors(enabled_tracker, sample_tensors):
"""Test that track clones tensors instead of storing references."""
x_t_original = sample_tensors["x_t"].clone()
enabled_tracker.track(time=0.5, x_t=sample_tensors["x_t"])
# Modify original tensor
sample_tensors["x_t"].fill_(999.0)
# Tracked tensor should not be affected
steps = enabled_tracker.get_all_steps()
assert not torch.equal(steps[0].x_t, sample_tensors["x_t"])
assert torch.equal(steps[0].x_t, x_t_original)
def test_track_with_none_values(enabled_tracker):
"""Test track handles None values correctly."""
enabled_tracker.track(
time=0.4,
x_t=None,
v_t=None,
guidance_weight=None,
inference_delay=None,
)
steps = enabled_tracker.get_all_steps()
assert len(steps) == 1
assert steps[0].x_t is None
assert steps[0].v_t is None
assert steps[0].guidance_weight is None
assert steps[0].inference_delay is None
def test_track_updates_only_non_none_fields(enabled_tracker, sample_tensors):
"""Test that update preserves existing values when None is passed."""
# Create step with x_t
enabled_tracker.track(time=0.3, x_t=sample_tensors["x_t"], guidance_weight=2.0)
# Update with v_t only (pass None for other fields)
enabled_tracker.track(time=0.3, v_t=sample_tensors["v_t"], x_t=None, guidance_weight=None)
# Original values should be preserved
steps = enabled_tracker.get_all_steps()
assert torch.equal(steps[0].x_t, sample_tensors["x_t"]) # Still has x_t
assert torch.equal(steps[0].v_t, sample_tensors["v_t"]) # Now has v_t
assert steps[0].guidance_weight == 2.0 # Still has guidance_weight
# ====================== Tracker.maxlen Tests ======================
def test_tracker_enforces_maxlen():
"""Test that tracker enforces maxlen limit."""
tracker = Tracker(enabled=True, maxlen=3)
# Add 5 steps
for i in range(5):
time = 1.0 - i * 0.1 # 1.0, 0.9, 0.8, 0.7, 0.6
tracker.track(time=time, x_t=torch.randn(1, 10, 6))
# Should only keep the last 3
assert len(tracker) == 3
# Verify oldest steps were removed (should have 0.6, 0.7, 0.8)
steps = tracker.get_all_steps()
times = sorted([step.time for step in steps])
assert times == [0.6, 0.7, 0.8]
def test_tracker_step_idx_increments_despite_maxlen():
"""Test that step_idx continues incrementing even when maxlen is enforced."""
tracker = Tracker(enabled=True, maxlen=2)
# Add 4 steps
for i in range(4):
time = 1.0 - i * 0.1
tracker.track(time=time, x_t=torch.randn(1, 10, 6))
# Should have 2 steps with step_idx 2 and 3 (oldest removed)
steps = sorted(tracker.get_all_steps(), key=lambda s: s.step_idx)
assert len(steps) == 2
assert steps[0].step_idx == 2
assert steps[1].step_idx == 3
def test_tracker_without_maxlen_keeps_all():
"""Test that tracker without maxlen keeps all steps."""
tracker = Tracker(enabled=True, maxlen=None)
# Add 100 steps
for i in range(100):
time = 1.0 - i * 0.01
tracker.track(time=time, x_t=torch.randn(1, 10, 6))
assert len(tracker) == 100
def test_get_all_steps_returns_empty_when_disabled(disabled_tracker):
"""Test get_all_steps returns empty list when disabled."""
steps = disabled_tracker.get_all_steps()
assert steps == []
assert isinstance(steps, list)
def test_get_all_steps_returns_empty_when_no_steps(enabled_tracker):
"""Test get_all_steps returns empty list when no steps tracked."""
steps = enabled_tracker.get_all_steps()
assert steps == []
def test_get_all_steps_returns_all_tracked_steps(enabled_tracker, sample_tensors):
"""Test get_all_steps returns all tracked steps."""
# Track 5 steps
for i in range(5):
time = 1.0 - i * 0.1
enabled_tracker.track(time=time, x_t=sample_tensors["x_t"])
steps = enabled_tracker.get_all_steps()
assert len(steps) == 5
# Verify all are DebugStep instances
for step in steps:
assert isinstance(step, DebugStep)
def test_get_all_steps_preserves_insertion_order(enabled_tracker):
"""Test that get_all_steps preserves insertion order (Python 3.7+)."""
times = [0.9, 0.8, 0.7, 0.6, 0.5]
for time in times:
enabled_tracker.track(time=time, x_t=torch.randn(1, 10, 6))
steps = enabled_tracker.get_all_steps()
retrieved_times = [step.time for step in steps]
# Should be in insertion order
assert retrieved_times == times
# ====================== Tracker.__len__() Tests ======================
def test_len_returns_zero_when_disabled(disabled_tracker):
"""Test __len__ returns 0 when tracker is disabled."""
assert len(disabled_tracker) == 0
def test_len_returns_zero_when_empty(enabled_tracker):
"""Test __len__ returns 0 when no steps are tracked."""
assert len(enabled_tracker) == 0
def test_len_returns_correct_count(enabled_tracker, sample_tensors):
"""Test __len__ returns correct number of tracked steps."""
assert len(enabled_tracker) == 0
enabled_tracker.track(time=1.0, x_t=sample_tensors["x_t"])
assert len(enabled_tracker) == 1
enabled_tracker.track(time=0.9, x_t=sample_tensors["x_t"])
assert len(enabled_tracker) == 2
enabled_tracker.track(time=0.8, x_t=sample_tensors["x_t"])
assert len(enabled_tracker) == 3
def test_len_after_reset(enabled_tracker, sample_tensors):
"""Test __len__ returns 0 after reset."""
enabled_tracker.track(time=1.0, x_t=sample_tensors["x_t"])
enabled_tracker.track(time=0.9, x_t=sample_tensors["x_t"])
assert len(enabled_tracker) == 2
enabled_tracker.reset()
assert len(enabled_tracker) == 0
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_tracker_handles_gpu_tensors():
"""Test tracker correctly handles GPU tensors."""
tracker = Tracker(enabled=True, maxlen=10)
x_t_gpu = torch.randn(1, 50, 6, device="cuda")
tracker.track(time=1.0, x_t=x_t_gpu)
steps = tracker.get_all_steps()
# Tracker should clone and detach tensors
assert steps[0].x_t.device.type == "cuda"
def test_tracker_with_varying_tensor_shapes(enabled_tracker):
"""Test tracker handles varying tensor shapes across steps."""
enabled_tracker.track(time=1.0, x_t=torch.randn(1, 50, 6))
enabled_tracker.track(time=0.9, x_t=torch.randn(1, 25, 6))
enabled_tracker.track(time=0.8, x_t=torch.randn(2, 50, 8))
steps = enabled_tracker.get_all_steps()
assert len(steps) == 3
assert steps[0].x_t.shape == (1, 50, 6)
assert steps[1].x_t.shape == (1, 25, 6)
assert steps[2].x_t.shape == (2, 50, 8)
+322
View File
@@ -0,0 +1,322 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for RTC LatencyTracker module."""
import pytest
from lerobot.policies.rtc.latency_tracker import LatencyTracker
# ====================== Fixtures ======================
@pytest.fixture
def tracker():
"""Create a LatencyTracker with default maxlen."""
return LatencyTracker(maxlen=100)
@pytest.fixture
def small_tracker():
"""Create a LatencyTracker with small maxlen for overflow testing."""
return LatencyTracker(maxlen=5)
# ====================== Initialization Tests ======================
def test_latency_tracker_initialization():
"""Test LatencyTracker initializes correctly."""
tracker = LatencyTracker(maxlen=50)
assert len(tracker) == 0
assert tracker.max_latency == 0.0
assert tracker.max() == 0.0
def test_latency_tracker_default_maxlen():
"""Test LatencyTracker uses default maxlen."""
tracker = LatencyTracker()
# Should accept default maxlen=100
assert len(tracker) == 0
# ====================== add() Tests ======================
def test_add_single_latency(tracker):
"""Test adding a single latency value."""
tracker.add(0.5)
assert len(tracker) == 1
assert tracker.max() == 0.5
def test_add_multiple_latencies(tracker):
"""Test adding multiple latency values."""
latencies = [0.1, 0.5, 0.3, 0.8, 0.2]
for lat in latencies:
tracker.add(lat)
assert len(tracker) == 5
assert tracker.max() == 0.8
def test_add_negative_latency_ignored(tracker):
"""Test that negative latencies are ignored."""
tracker.add(0.5)
tracker.add(-0.1)
tracker.add(0.3)
# Should only have 2 valid latencies
assert len(tracker) == 2
assert tracker.max() == 0.5
def test_add_zero_latency(tracker):
"""Test adding zero latency."""
tracker.add(0.0)
assert len(tracker) == 1
assert tracker.max() == 0.0
def test_add_converts_to_float(tracker):
"""Test add() converts input to float."""
tracker.add(5) # Integer
tracker.add("3.5") # String
assert len(tracker) == 2
assert tracker.max() == 5.0
def test_add_updates_max_latency(tracker):
"""Test that max_latency is updated correctly."""
tracker.add(0.5)
assert tracker.max_latency == 0.5
tracker.add(0.3)
assert tracker.max_latency == 0.5 # Should not decrease
tracker.add(0.9)
assert tracker.max_latency == 0.9 # Should increase
# ====================== reset() Tests ======================
def test_reset_clears_values(tracker):
"""Test reset() clears all values."""
tracker.add(0.5)
tracker.add(0.8)
tracker.add(0.3)
assert len(tracker) == 3
tracker.reset()
assert len(tracker) == 0
assert tracker.max_latency == 0.0
def test_reset_clears_max_latency(tracker):
"""Test reset() resets max_latency."""
tracker.add(1.5)
assert tracker.max_latency == 1.5
tracker.reset()
assert tracker.max_latency == 0.0
def test_reset_allows_new_values(tracker):
"""Test that tracker works correctly after reset."""
tracker.add(0.5)
tracker.reset()
tracker.add(0.3)
assert len(tracker) == 1
assert tracker.max() == 0.3
# ====================== max() Tests ======================
def test_max_returns_zero_when_empty(tracker):
"""Test max() returns 0.0 when tracker is empty."""
assert tracker.max() == 0.0
def test_max_returns_maximum_value(tracker):
"""Test max() returns the maximum latency."""
latencies = [0.2, 0.8, 0.3, 0.5, 0.1]
for lat in latencies:
tracker.add(lat)
assert tracker.max() == 0.8
def test_max_persists_after_sliding_window(small_tracker):
"""Test max() persists even after values slide out of window."""
# Add values that will exceed maxlen=5
small_tracker.add(0.1)
small_tracker.add(0.9) # This is max
small_tracker.add(0.2)
small_tracker.add(0.3)
small_tracker.add(0.4)
small_tracker.add(0.5) # This pushes out 0.1
# Max should still be 0.9 even though only last 5 values kept
assert small_tracker.max() == 0.9
def test_max_after_reset(tracker):
"""Test max() returns 0.0 after reset."""
tracker.add(1.5)
tracker.reset()
assert tracker.max() == 0.0
# ====================== p95() Tests ======================
def test_p95_returns_zero_when_empty(tracker):
"""Test p95() returns 0.0 when tracker is empty."""
assert tracker.p95() == 0.0
def test_p95_returns_95th_percentile(tracker):
"""Test p95() returns the 95th percentile."""
# Add 100 values
for i in range(100):
tracker.add(i / 100.0)
p95 = tracker.p95()
assert 0.93 <= p95 <= 0.96
def test_p95_equals_percentile_95(tracker):
"""Test p95() equals percentile(0.95)."""
for i in range(50):
tracker.add(i / 50.0)
assert tracker.p95() == tracker.percentile(0.95)
# ====================== Edge Cases Tests ======================
def test_single_value(tracker):
"""Test tracker behavior with single value."""
tracker.add(0.75)
assert len(tracker) == 1
assert tracker.max() == 0.75
assert tracker.percentile(0.0) == 0.75
assert tracker.percentile(0.5) == 0.75
assert tracker.percentile(1.0) == 0.75
def test_all_same_values(tracker):
"""Test tracker with all identical values."""
for _ in range(10):
tracker.add(0.5)
assert len(tracker) == 10
assert tracker.max() == 0.5
assert tracker.percentile(0.0) == 0.5
assert tracker.percentile(0.5) == 0.5
assert tracker.percentile(1.0) == 0.5
def test_very_small_values(tracker):
"""Test tracker with very small float values."""
tracker.add(1e-10)
tracker.add(2e-10)
tracker.add(3e-10)
assert len(tracker) == 3
assert tracker.max() == pytest.approx(3e-10)
def test_very_large_values(tracker):
"""Test tracker with very large float values."""
tracker.add(1e10)
tracker.add(2e10)
tracker.add(3e10)
assert len(tracker) == 3
assert tracker.max() == pytest.approx(3e10)
# ====================== Integration Tests ======================
def test_typical_usage_pattern(tracker):
"""Test a typical usage pattern of the tracker."""
# Simulate adding latencies over time
latencies = [0.05, 0.08, 0.12, 0.07, 0.15, 0.09, 0.11, 0.06, 0.14, 0.10]
for lat in latencies:
tracker.add(lat)
# Check statistics
assert len(tracker) == 10
assert tracker.max() == 0.15
# p95 should be close to max since we have only 10 values
p95 = tracker.p95()
assert p95 >= tracker.percentile(0.5) # p95 should be >= median
assert p95 <= tracker.max() # p95 should be <= max
def test_reset_and_reuse(tracker):
"""Test resetting and reusing tracker."""
# First batch
tracker.add(1.0)
tracker.add(2.0)
assert tracker.max() == 2.0
# Reset
tracker.reset()
# Second batch
tracker.add(0.5)
tracker.add(0.8)
assert len(tracker) == 2
assert tracker.max() == 0.8
assert tracker.percentile(0.5) <= 0.8
# ====================== Type Conversion Tests ======================
def test_add_with_integer(tracker):
"""Test adding integer values."""
tracker.add(5)
assert len(tracker) == 1
assert tracker.max() == 5.0
def test_add_with_string_number(tracker):
"""Test adding string representation of number."""
tracker.add("3.14")
assert len(tracker) == 1
assert tracker.max() == pytest.approx(3.14)
def test_percentile_converts_q_to_float(tracker):
"""Test percentile converts q parameter to float."""
tracker.add(0.5)
tracker.add(0.8)
# Pass integer q
result = tracker.percentile(1)
assert result == 0.8
+773
View File
@@ -0,0 +1,773 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for RTC modeling module (RTCProcessor)."""
import pytest
import torch
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
# ====================== Fixtures ======================
@pytest.fixture
def rtc_config_debug_enabled():
"""Create RTC config with debug enabled."""
return RTCConfig(
enabled=True,
prefix_attention_schedule=RTCAttentionSchedule.LINEAR,
max_guidance_weight=10.0,
execution_horizon=10,
debug=True,
debug_maxlen=100,
)
@pytest.fixture
def rtc_config_debug_disabled():
"""Create RTC config with debug disabled."""
return RTCConfig(
enabled=True,
prefix_attention_schedule=RTCAttentionSchedule.LINEAR,
max_guidance_weight=10.0,
execution_horizon=10,
debug=False,
)
@pytest.fixture
def rtc_processor_debug_enabled(rtc_config_debug_enabled):
"""Create RTCProcessor with debug enabled."""
return RTCProcessor(rtc_config_debug_enabled)
@pytest.fixture
def rtc_processor_debug_disabled(rtc_config_debug_disabled):
"""Create RTCProcessor with debug disabled."""
return RTCProcessor(rtc_config_debug_disabled)
@pytest.fixture
def sample_x_t():
"""Create sample x_t tensor (batch, time, action_dim)."""
return torch.randn(1, 50, 6)
@pytest.fixture
def sample_prev_chunk():
"""Create sample previous chunk tensor."""
return torch.randn(1, 50, 6)
# ====================== Initialization Tests ======================
def test_rtc_processor_initialization_with_debug(rtc_config_debug_enabled):
"""Test RTCProcessor initializes with debug tracker."""
processor = RTCProcessor(rtc_config_debug_enabled)
assert processor.rtc_config == rtc_config_debug_enabled
assert processor.tracker is not None
assert processor.tracker.enabled is True
def test_rtc_processor_initialization_without_debug(rtc_config_debug_disabled):
"""Test RTCProcessor initializes without debug tracker."""
processor = RTCProcessor(rtc_config_debug_disabled)
assert processor.rtc_config == rtc_config_debug_disabled
assert processor.tracker is None
# ====================== Tracker Proxy Methods Tests ======================
def test_track_when_tracker_enabled(rtc_processor_debug_enabled, sample_x_t):
"""Test track() forwards to tracker when enabled."""
rtc_processor_debug_enabled.track(
time=torch.tensor(0.5),
x_t=sample_x_t,
v_t=sample_x_t,
guidance_weight=2.0,
)
# Should have tracked one step
steps = rtc_processor_debug_enabled.get_all_debug_steps()
assert len(steps) == 1
assert steps[0].time == 0.5
def test_track_when_tracker_disabled(rtc_processor_debug_disabled, sample_x_t):
"""Test track() does nothing when tracker disabled."""
# Should not raise error
rtc_processor_debug_disabled.track(
time=torch.tensor(0.5),
x_t=sample_x_t,
v_t=sample_x_t,
)
# Should return empty list
steps = rtc_processor_debug_disabled.get_all_debug_steps()
assert len(steps) == 0
def test_get_all_debug_steps_when_enabled(rtc_processor_debug_enabled, sample_x_t):
"""Test get_all_debug_steps() returns tracked steps."""
rtc_processor_debug_enabled.track(time=torch.tensor(0.5), x_t=sample_x_t)
rtc_processor_debug_enabled.track(time=torch.tensor(0.4), x_t=sample_x_t)
steps = rtc_processor_debug_enabled.get_all_debug_steps()
assert len(steps) == 2
def test_get_all_debug_steps_when_disabled(rtc_processor_debug_disabled):
"""Test get_all_debug_steps() returns empty list when disabled."""
steps = rtc_processor_debug_disabled.get_all_debug_steps()
assert steps == []
assert isinstance(steps, list)
def test_is_debug_enabled_when_tracker_exists(rtc_processor_debug_enabled):
"""Test is_debug_enabled() returns True when tracker enabled."""
assert rtc_processor_debug_enabled.is_debug_enabled() is True
def test_is_debug_enabled_when_tracker_disabled(rtc_processor_debug_disabled):
"""Test is_debug_enabled() returns False when tracker disabled."""
assert rtc_processor_debug_disabled.is_debug_enabled() is False
def test_reset_tracker_when_enabled(rtc_processor_debug_enabled, sample_x_t):
"""Test reset_tracker() clears tracked steps."""
rtc_processor_debug_enabled.track(time=torch.tensor(0.5), x_t=sample_x_t)
rtc_processor_debug_enabled.track(time=torch.tensor(0.4), x_t=sample_x_t)
assert len(rtc_processor_debug_enabled.get_all_debug_steps()) == 2
rtc_processor_debug_enabled.reset_tracker()
assert len(rtc_processor_debug_enabled.get_all_debug_steps()) == 0
def test_reset_tracker_when_disabled(rtc_processor_debug_disabled):
"""Test reset_tracker() doesn't error when tracker disabled."""
rtc_processor_debug_disabled.reset_tracker() # Should not raise
# ====================== get_prefix_weights Tests ======================
def test_get_prefix_weights_zeros_schedule():
"""Test get_prefix_weights with ZEROS schedule."""
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.ZEROS)
processor = RTCProcessor(config)
weights = processor.get_prefix_weights(start=5, end=10, total=20)
# First 5 should be 1.0, rest should be 0.0
assert weights.shape == (20,)
assert torch.all(weights[:5] == 1.0)
assert torch.all(weights[5:] == 0.0)
def test_get_prefix_weights_ones_schedule():
"""Test get_prefix_weights with ONES schedule."""
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.ONES)
processor = RTCProcessor(config)
weights = processor.get_prefix_weights(start=5, end=15, total=20)
# First 15 should be 1.0, rest should be 0.0
assert weights.shape == (20,)
assert torch.all(weights[:15] == 1.0)
assert torch.all(weights[15:] == 0.0)
def test_get_prefix_weights_linear_schedule():
"""Test get_prefix_weights with LINEAR schedule."""
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.LINEAR)
processor = RTCProcessor(config)
weights = processor.get_prefix_weights(start=5, end=14, total=25)
# Should have shape (20,)
assert weights.shape == (25,)
# First 5 should be 1.0 (leading ones)
assert torch.all(weights[:5] == 1.0)
# Middle section (5:15) should be linearly decreasing from 1 to 0
middle_weights = torch.tensor([0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1])
assert torch.allclose(weights[5:14], middle_weights)
# Last 5 should be 0.0 (trailing zeros)
assert torch.all(weights[14:] == 0.0)
def test_get_prefix_weights_exp_schedule():
"""Test get_prefix_weights with EXP schedule."""
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.EXP)
processor = RTCProcessor(config)
weights = processor.get_prefix_weights(start=5, end=14, total=25)
# Should have shape (20,)
assert weights.shape == (25,)
# First 5 should be 1.0 (leading ones)
assert torch.all(weights[:5] == 1.0)
# Middle section should be exponentially weighted
middle_weights = torch.tensor([0.7645, 0.5706, 0.4130, 0.2871, 0.1888, 0.1145, 0.0611, 0.0258, 0.0061])
assert torch.allclose(weights[5:14], middle_weights, atol=1e-4)
# Last 5 should be 0.0 (trailing zeros)
assert torch.all(weights[14:] == 0.0)
def test_get_prefix_weights_with_start_equals_end():
"""Test get_prefix_weights when start equals end."""
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.LINEAR)
processor = RTCProcessor(config)
weights = processor.get_prefix_weights(start=10, end=10, total=20)
# Should have ones up to start, then zeros
assert torch.all(weights[:10] == 1.0)
assert torch.all(weights[10:] == 0.0)
def test_get_prefix_weights_with_start_greater_than_end():
"""Test get_prefix_weights when start > end (gets clamped)."""
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.LINEAR)
processor = RTCProcessor(config)
# start > end should use min(start, end) = end
weights = processor.get_prefix_weights(start=15, end=10, total=20)
# Should have ones up to end (10), then zeros
assert torch.all(weights[:10] == 1.0)
assert torch.all(weights[10:] == 0.0)
# ====================== Helper Method Tests ======================
def test_linweights_with_end_equals_start():
"""Test _linweights when end equals start."""
config = RTCConfig()
processor = RTCProcessor(config)
weights = processor._linweights(start=10, end=10, total=20)
# Should return empty tensor
assert len(weights) == 0
def test_linweights_with_end_less_than_start():
"""Test _linweights when end < start."""
config = RTCConfig()
processor = RTCProcessor(config)
weights = processor._linweights(start=15, end=10, total=20)
# Should return empty tensor
assert len(weights) == 0
def test_add_trailing_zeros_normal():
"""Test _add_trailing_zeros adds zeros correctly."""
config = RTCConfig()
processor = RTCProcessor(config)
weights = torch.tensor([1.0, 0.8, 0.6, 0.4, 0.2])
result = processor._add_trailing_zeros(weights, total=10, end=5)
# Should add 5 zeros (total - end = 10 - 5 = 5)
assert len(result) == 10
assert torch.all(result[:5] == weights)
assert torch.all(result[5:] == 0.0)
def test_add_trailing_zeros_no_zeros_needed():
"""Test _add_trailing_zeros when no zeros needed."""
config = RTCConfig()
processor = RTCProcessor(config)
weights = torch.tensor([1.0, 0.8, 0.6])
result = processor._add_trailing_zeros(weights, total=3, end=5)
# zeros_len = 3 - 5 = -2 <= 0, so no zeros added
assert torch.equal(result, weights)
def test_add_leading_ones_normal():
"""Test _add_leading_ones adds ones correctly."""
config = RTCConfig()
processor = RTCProcessor(config)
weights = torch.tensor([0.8, 0.6, 0.4, 0.2, 0.0])
result = processor._add_leading_ones(weights, start=3, total=10)
# Should add 3 ones at the start
assert len(result) == 8
assert torch.all(result[:3] == 1.0)
assert torch.all(result[3:] == weights)
def test_add_leading_ones_no_ones_needed():
"""Test _add_leading_ones when no ones needed."""
config = RTCConfig()
processor = RTCProcessor(config)
weights = torch.tensor([0.8, 0.6, 0.4])
result = processor._add_leading_ones(weights, start=0, total=10)
# ones_len = 0, so no ones added
assert torch.equal(result, weights)
def test_get_prefix_weights_with_start_equals_total():
"""Test get_prefix_weights when start equals total."""
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.LINEAR)
processor = RTCProcessor(config)
weights = processor.get_prefix_weights(start=10, end=10, total=20)
# Should have ones up to start, then zeros
assert len(weights) == 20
assert torch.all(weights[:10] == 1.0)
assert torch.all(weights[10:] == 0.0)
def test_get_prefix_weights_with_total_less_than_start():
"""Test get_prefix_weights when total less than start."""
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.LINEAR)
processor = RTCProcessor(config)
weights = processor.get_prefix_weights(start=10, end=10, total=5)
# Should have ones up to start, then zeros
assert len(weights) == 5
assert torch.all(weights == 1.0)
# ====================== denoise_step Tests ======================
def test_denoise_step_without_prev_chunk(rtc_processor_debug_disabled):
"""Test denoise_step without previous chunk (no guidance)."""
x_t = torch.randn(1, 50, 6)
# Mock denoiser that returns fixed velocity
def mock_denoiser(x):
return torch.ones_like(x) * 0.5
result = rtc_processor_debug_disabled.denoise_step(
x_t=x_t,
prev_chunk_left_over=None,
inference_delay=5,
time=torch.tensor(0.5),
original_denoise_step_partial=mock_denoiser,
)
# Should return v_t unchanged (no guidance)
expected = mock_denoiser(x_t)
assert torch.allclose(result, expected)
def test_denoise_step_with_prev_chunk(rtc_processor_debug_disabled):
"""Test denoise_step with previous chunk applies guidance."""
x_t = torch.ones(1, 20, 1)
prev_chunk = torch.full((1, 20, 1), 0.1)
def mock_denoiser(x):
return x * 0.5
result = rtc_processor_debug_disabled.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk,
inference_delay=5,
time=torch.tensor(0.5),
original_denoise_step_partial=mock_denoiser,
)
expected_result = torch.tensor(
[
[
[1.8000],
[1.8000],
[1.8000],
[1.8000],
[1.8000],
[1.5833],
[1.3667],
[1.1500],
[0.9333],
[0.7167],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
]
]
)
assert torch.allclose(result, expected_result, atol=1e-4)
def test_denoise_step_adds_batch_dimension():
"""Test denoise_step handles 2D input by adding batch dimension."""
config = RTCConfig(execution_horizon=10, max_guidance_weight=5.0)
processor = RTCProcessor(config)
# 2D input (no batch dimension)
x_t = torch.randn(10, 6)
prev_chunk = torch.randn(5, 6)
def mock_denoiser(x):
return x * 0.5
result = processor.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk,
inference_delay=5,
time=torch.tensor(0.5),
original_denoise_step_partial=mock_denoiser,
)
# Output should be 2D (batch dimension removed)
assert result.ndim == 2
assert result.shape == (10, 6)
def test_denoise_step_uses_custom_execution_horizon():
"""Test denoise_step uses custom execution_horizon parameter."""
config = RTCConfig(execution_horizon=10)
processor = RTCProcessor(config)
x_t = torch.ones(1, 20, 1)
prev_chunk = torch.full((1, 15, 1), 0.1)
def mock_denoiser(x):
return x * 0.5
result = processor.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk,
inference_delay=5,
time=torch.tensor(0.5),
original_denoise_step_partial=mock_denoiser,
execution_horizon=15,
)
expected_result = torch.tensor(
[
[
[1.8000],
[1.8000],
[1.8000],
[1.8000],
[1.8000],
[1.6818],
[1.5636],
[1.4455],
[1.3273],
[1.2091],
[1.0909],
[0.9727],
[0.8545],
[0.7364],
[0.6182],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
]
]
)
assert torch.allclose(result, expected_result, atol=1e-4)
def test_denoise_step_guidance_weight_at_time_zero():
"""Test denoise_step handles time=0 (tau=1) without NaN/Inf."""
config = RTCConfig(max_guidance_weight=10.0)
processor = RTCProcessor(config)
x_t = torch.ones(1, 20, 1)
prev_chunk = torch.full((1, 20, 1), 0.1)
def mock_denoiser(x):
return x * 0.5
result = processor.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk,
inference_delay=5,
time=torch.tensor(0.0),
original_denoise_step_partial=mock_denoiser,
)
expected_result = torch.tensor(
[
[
[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
]
]
)
assert torch.allclose(result, expected_result, atol=1e-4)
def test_denoise_step_with_real_denoise_step_partial():
"""Test denoise_step with a real denoiser."""
config = RTCConfig(max_guidance_weight=10.0)
processor = RTCProcessor(config)
batch_size = 10
action_dim = 6
chunk_size = 20
x_t = torch.ones(batch_size, chunk_size, action_dim)
prev_chunk = torch.full((batch_size, chunk_size, action_dim), 0.1)
velocity_function = torch.nn.Sequential(
torch.nn.Linear(action_dim, 1000),
torch.nn.ReLU(),
torch.nn.Linear(1000, 256),
torch.nn.ReLU(),
torch.nn.Linear(256, action_dim),
)
def mock_denoiser(x):
return velocity_function(x)
result = processor.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk,
inference_delay=5,
time=torch.tensor(0.5),
original_denoise_step_partial=mock_denoiser,
)
assert result.shape == (batch_size, chunk_size, action_dim)
def test_denoise_step_guidance_weight_at_time_one():
"""Test denoise_step handles time=1 (tau=0) with max_guidance_weight clamping."""
config = RTCConfig(max_guidance_weight=10.0)
processor = RTCProcessor(config)
x_t = torch.randn(1, 50, 6)
prev_chunk = torch.randn(1, 50, 6)
def mock_denoiser(x):
return torch.ones_like(x) * 0.5
# Time = 1 => tau = 0, c = (1-tau)/tau = 1/0 = inf (clamped to max_guidance_weight)
result = processor.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk,
inference_delay=5,
time=torch.tensor(1.0),
original_denoise_step_partial=mock_denoiser,
)
# Should clamp to max_guidance_weight (no Inf)
assert not torch.any(torch.isinf(result))
def test_denoise_step_tracks_debug_info(rtc_processor_debug_enabled):
"""Test denoise_step tracks debug information when enabled."""
x_t = torch.randn(1, 50, 6)
prev_chunk = torch.randn(1, 50, 6)
def mock_denoiser(x):
return torch.ones_like(x) * 0.5
rtc_processor_debug_enabled.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk,
inference_delay=5,
time=torch.tensor(0.5),
original_denoise_step_partial=mock_denoiser,
)
# Should have tracked one step
steps = rtc_processor_debug_enabled.get_all_debug_steps()
assert len(steps) == 1
# Check tracked values
step = steps[0]
assert step.time == 0.5
assert step.x1_t is not None
assert step.correction is not None
assert step.err is not None
assert step.weights is not None
assert step.guidance_weight is not None
assert step.inference_delay == 5
def test_denoise_step_doesnt_track_without_debug(rtc_processor_debug_disabled):
"""Test denoise_step doesn't track when debug disabled."""
x_t = torch.randn(1, 50, 6)
prev_chunk = torch.randn(1, 50, 6)
def mock_denoiser(x):
return torch.ones_like(x) * 0.5
rtc_processor_debug_disabled.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk,
inference_delay=5,
time=torch.tensor(0.5),
original_denoise_step_partial=mock_denoiser,
)
# Should not track
steps = rtc_processor_debug_disabled.get_all_debug_steps()
assert len(steps) == 0
# ====================== Integration Tests ======================
def test_denoise_step_full_workflow():
"""Test complete denoise_step workflow."""
config = RTCConfig(
enabled=True,
prefix_attention_schedule=RTCAttentionSchedule.LINEAR,
max_guidance_weight=5.0,
execution_horizon=10,
debug=True,
)
processor = RTCProcessor(config)
# Simulate two denoising steps
x_t1 = torch.randn(1, 50, 6)
x_t2 = torch.randn(1, 50, 6)
def mock_denoiser(x):
return torch.randn_like(x) * 0.1
# First step - no guidance
result1 = processor.denoise_step(
x_t=x_t1,
prev_chunk_left_over=None,
inference_delay=5,
time=torch.tensor(0.8),
original_denoise_step_partial=mock_denoiser,
)
# Second step - with guidance
result2 = processor.denoise_step(
x_t=x_t2,
prev_chunk_left_over=result1,
inference_delay=5,
time=torch.tensor(0.6),
original_denoise_step_partial=mock_denoiser,
)
# Both should complete successfully
assert result1.shape == (1, 50, 6)
assert result2.shape == (1, 50, 6)
# Should have tracked one step (second one, first had no prev_chunk)
steps = processor.get_all_debug_steps()
assert len(steps) == 1
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_denoise_step_with_cuda_tensors():
"""Test denoise_step works with CUDA tensors."""
config = RTCConfig(execution_horizon=10, max_guidance_weight=5.0)
processor = RTCProcessor(config)
x_t = torch.randn(1, 50, 6, device="cuda")
prev_chunk = torch.randn(1, 50, 6, device="cuda")
def mock_denoiser(x):
return torch.ones_like(x) * 0.5
result = processor.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk,
inference_delay=5,
time=torch.tensor(0.5),
original_denoise_step_partial=mock_denoiser,
)
# Result should be on CUDA
assert result.device.type == "cuda"
assert result.shape == x_t.shape
def test_denoise_step_deterministic_with_same_inputs():
"""Test denoise_step produces same output with same inputs."""
config = RTCConfig(execution_horizon=10, max_guidance_weight=5.0)
processor = RTCProcessor(config)
torch.manual_seed(42)
x_t = torch.randn(1, 50, 6)
prev_chunk = torch.randn(1, 50, 6)
def deterministic_denoiser(x):
return torch.ones_like(x) * 0.5
result1 = processor.denoise_step(
x_t=x_t.clone(),
prev_chunk_left_over=prev_chunk.clone(),
inference_delay=5,
time=torch.tensor(0.5),
original_denoise_step_partial=deterministic_denoiser,
)
result2 = processor.denoise_step(
x_t=x_t.clone(),
prev_chunk_left_over=prev_chunk.clone(),
inference_delay=5,
time=torch.tensor(0.5),
original_denoise_step_partial=deterministic_denoiser,
)
# Should produce identical results
assert torch.allclose(result1, result2)
+309
View File
@@ -0,0 +1,309 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test SmolVLA policy with Real-Time Chunking (RTC) enabled during inference."""
import pytest
import torch
from lerobot.configs.types import FeatureType, PolicyFeature, RTCAttentionSchedule # noqa: E402
from lerobot.policies.factory import make_pre_post_processors # noqa: E402
from lerobot.policies.rtc.configuration_rtc import RTCConfig # noqa: E402
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig # noqa: F401
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy # noqa: F401
from lerobot.utils.random_utils import set_seed # noqa: E402
from tests.utils import require_cuda # noqa: E402
@require_cuda
def test_smolvla_rtc_initialization():
"""Test SmolVLA policy can initialize RTC processor."""
set_seed(42)
config = SmolVLAConfig(max_action_dim=7, chunk_size=50)
# Add RTC config
config.rtc_config = RTCConfig(
enabled=True,
execution_horizon=10,
max_guidance_weight=5.0,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
debug=False,
)
config.input_features = {
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
}
# Instantiate policy
policy = SmolVLAPolicy(config)
# Verify RTC processor is initialized
assert hasattr(policy, "rtc_processor")
assert policy.rtc_processor is not None
assert policy.rtc_processor.rtc_config.enabled is True
print("✓ SmolVLA RTC initialization: Test passed")
@require_cuda
def test_smolvla_rtc_initialization_without_rtc_config():
"""Test SmolVLA policy can initialize without RTC config."""
set_seed(42)
config = SmolVLAConfig(max_action_dim=7, chunk_size=50)
# Instantiate policy
policy = SmolVLAPolicy(config)
# Verify RTC processor is not initialized
assert hasattr(policy, "rtc_processor")
assert policy.rtc_processor is None
assert policy.model.rtc_processor is None
assert policy._rtc_enabled() is False
print("✓ SmolVLA RTC initialization without RTC config: Test passed")
@require_cuda
@pytest.mark.skipif(True, reason="Requires pretrained SmolVLA model weights")
def test_smolvla_rtc_inference_with_prev_chunk():
"""Test SmolVLA policy inference with RTC and previous chunk."""
set_seed(42)
config = SmolVLAConfig(max_action_dim=7, chunk_size=50)
# Add RTC config
config.rtc_config = RTCConfig(
enabled=True,
execution_horizon=10,
max_guidance_weight=5.0,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
debug=False,
)
config.input_features = {
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
}
# Create dataset stats
dataset_stats = {
"observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
"action": {"mean": torch.zeros(7), "std": torch.ones(7)},
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
}
# Instantiate policy and create preprocessor
policy = SmolVLAPolicy(config)
policy.eval()
preprocessor, _ = make_pre_post_processors(
policy_cfg=config, pretrained_path=None, dataset_stats=dataset_stats
)
device = config.device
# Create dummy batch
batch = {
"observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
"observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
"task": ["Pick up the object"],
}
batch = preprocessor(batch)
# Create previous chunk
prev_chunk = torch.randn(1, 25, 7, dtype=torch.float32, device=device)
with torch.no_grad():
# Use same noise for fair comparison
noise = policy.model.sample_noise((1, config.chunk_size, 7), device)
# Test with RTC and previous chunk
actions_with_rtc = policy.predict_action_chunk(
batch,
noise=noise.clone(),
prev_chunk_left_over=prev_chunk,
inference_delay=4,
execution_horizon=10,
)
# Test without RTC for comparison
policy.config.rtc_config.enabled = False
actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone())
policy.config.rtc_config.enabled = True
# Verify shapes
assert actions_with_rtc.shape == (1, config.chunk_size, 7)
assert actions_without_rtc.shape == (1, config.chunk_size, 7)
# With previous chunk, actions should be different (RTC guidance applied)
assert not torch.allclose(actions_with_rtc, actions_without_rtc, rtol=1e-3)
print("✓ SmolVLA RTC inference with prev_chunk: Test passed")
@require_cuda
@pytest.mark.skipif(True, reason="Requires pretrained SmolVLA model weights")
def test_smolvla_rtc_inference_without_prev_chunk():
"""Test SmolVLA policy inference with RTC but no previous chunk (RTC should have no effect)."""
set_seed(42)
config = SmolVLAConfig(max_action_dim=7, chunk_size=50)
# Add RTC config
config.rtc_config = RTCConfig(
enabled=True,
execution_horizon=10,
max_guidance_weight=5.0,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
debug=False,
)
config.input_features = {
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
}
# Create dataset stats
dataset_stats = {
"observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
"action": {"mean": torch.zeros(7), "std": torch.ones(7)},
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
}
# Instantiate policy and create preprocessor
policy = SmolVLAPolicy(config)
policy.eval()
preprocessor, _ = make_pre_post_processors(
policy_cfg=config, pretrained_path=None, dataset_stats=dataset_stats
)
device = config.device
# Create dummy batch
batch = {
"observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
"observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
"task": ["Pick up the object"],
}
batch = preprocessor(batch)
with torch.no_grad():
# Use same noise for fair comparison
noise = policy.model.sample_noise((1, config.chunk_size, 7), device)
# Test with RTC enabled but no previous chunk
actions_with_rtc_no_prev = policy.predict_action_chunk(
batch,
noise=noise.clone(),
prev_chunk_left_over=None,
)
# Test without RTC
policy.config.rtc_config.enabled = False
actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone())
policy.config.rtc_config.enabled = True
# Without previous chunk, RTC should have no effect
assert torch.allclose(actions_with_rtc_no_prev, actions_without_rtc, rtol=1e-5)
print("✓ SmolVLA RTC inference without prev_chunk: Test passed")
@require_cuda
@pytest.mark.skipif(True, reason="Requires pretrained SmolVLA model weights")
def test_smolvla_rtc_validation_rules():
"""Test SmolVLA policy with RTC follows all three validation rules."""
set_seed(42)
config = SmolVLAConfig(max_action_dim=7, chunk_size=50)
# Add RTC config
config.rtc_config = RTCConfig(
enabled=True,
execution_horizon=10,
max_guidance_weight=5.0,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
debug=False,
)
config.input_features = {
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
}
# Create dataset stats
dataset_stats = {
"observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
"action": {"mean": torch.zeros(7), "std": torch.ones(7)},
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
}
# Instantiate policy and create preprocessor
policy = SmolVLAPolicy(config)
policy.eval()
preprocessor, _ = make_pre_post_processors(
policy_cfg=config, pretrained_path=None, dataset_stats=dataset_stats
)
device = config.device
# Create dummy batch
batch = {
"observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
"observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
"task": ["Pick up the object"],
}
batch = preprocessor(batch)
# Create previous chunk
prev_chunk = torch.randn(1, 25, 7, dtype=torch.float32, device=device)
inference_delay = 4
execution_horizon = 10
with torch.no_grad():
# Use same noise for fair comparison
noise = policy.model.sample_noise((1, config.chunk_size, 7), device)
# Test with RTC
actions_with_rtc = policy.predict_action_chunk(
batch,
noise=noise.clone(),
prev_chunk_left_over=prev_chunk,
inference_delay=inference_delay,
execution_horizon=execution_horizon,
)
# Test without RTC
policy.config.rtc_config.enabled = False
actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone())
policy.config.rtc_config.enabled = True
assert not torch.allclose(actions_with_rtc, actions_without_rtc, rtol=1e-3)