mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 08:39:49 +00:00
8a915c6b6f
* Add Real-Time Chunking (RTC) support for flow matching models Implement Real-Time Chunking (RTC) for action chunking policies using flow matching denoising. RTC enables smooth action transitions between consecutive chunks by using prefix guidance during denoising. Key features: - RTCProcessor class with denoise_step method for RTC guidance - Tracker system for debug tracking using time-based dictionary storage - RTCDebugVisualizer with comprehensive visualization utilities - Integration with SmolVLA policy for flow matching models - Support for multiple prefix attention schedules (ZEROS, ONES, LINEAR, EXP) - Configurable execution horizon and max guidance weight - Example scripts for dataset evaluation and real-time control Technical details: - Uses autograd-based gradient computation for RTC corrections - Time-based tracking eliminates duplicate step issues - Proxy methods in RTCProcessor for cleaner API - Full integration with LeRobot's policy and dataset systems Files added/modified: - src/lerobot/configs/types.py: Add RTCAttentionSchedule enum - src/lerobot/policies/rtc/: Core RTC implementation - configuration_rtc.py: RTC configuration - modeling_rtc.py: RTCProcessor with denoise_step - debug_handler.py: Tracker for debug information - debug_visualizer.py: Visualization utilities - src/lerobot/policies/smolvla/modeling_smolvla.py: RTC integration - examples/rtc/: Example scripts and evaluation tools 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com> Co-Authored-By: Claude <noreply@anthropic.com> * Fix rtc_config attribute access in SmolVLA Use getattr() to safely check for rtc_config attribute existence instead of direct attribute access. This fixes AttributeError when loading policies without rtc_config in their config. 
🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com> Co-Authored-By: Claude <noreply@anthropic.com> * fixup! Fix rtc_config attribute access in SmolVLA * Add RTCConfig field to SmolVLAConfig Add rtc_config as an optional field in SmolVLAConfig to properly support Real-Time Chunking configuration. This replaces the previous getattr() workarounds with direct attribute access, making the code cleaner and more maintainable. Changes: - Import RTCConfig in configuration_smolvla.py - Add rtc_config: RTCConfig | None = None field - Revert getattr() calls to direct attribute access in modeling_smolvla.py 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com> Co-Authored-By: Claude <noreply@anthropic.com> * Refactor RTC enabled checks to use _rtc_enabled helper Add _rtc_enabled() helper method in VLAFlowMatching class to simplify and clean up RTC enabled checks throughout the code. This reduces code duplication and improves readability. Changes: - Add _rtc_enabled() method in VLAFlowMatching - Replace verbose rtc_config checks with _rtc_enabled() calls - Maintain exact same functionality with cleaner code 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com> Co-Authored-By: Claude <noreply@anthropic.com> * Rename track_debug method to track Simplify the method name from track_debug to just track for better readability and consistency. The method already has clear documentation about its debug tracking purpose. 
Changes: - Rename RTCProcessor.track_debug() to track() - Update all call sites in modeling_smolvla.py and modeling_rtc.py 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com> Co-Authored-By: Claude <noreply@anthropic.com> * Use output_dir for saving all evaluation images Update eval_dataset.py to save all comparison images to the configured output_dir instead of the current directory. This provides better organization and allows users to specify where outputs should be saved. Changes: - Add os import at top level - Create output_dir at start of run_evaluation() - Save all comparison images to output_dir - Remove duplicate os imports - Update init_rtc_processor() docstring to be more concise 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com> Co-Authored-By: Claude <noreply@anthropic.com> * fixup! Use output_dir for saving all evaluation images * Fix logging buffering and enable tracking when RTC config provided - Add force=True to logging.basicConfig to override existing configuration - Enable line buffering for stdout/stderr for real-time log output - Modify init_rtc_processor to create processor when rtc_config exists even if RTC is disabled, allowing tracking of denoising data 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com> * Refactor SmolVLA plotting to use tracker data instead of local variables Remove local tracking variables (correction, x1_t, error) from the denoising loop and instead retrieve plotting data from the RTC tracker after each denoise step. This makes the code cleaner and uses the tracker as the single source of truth for debug/visualization data. 
Changes: - Remove initialization of correction, x1_t, error before denoising loop - After each Euler step, retrieve most recent debug step from tracker - Extract correction, x1_t, err from debug step for plotting - Update tracking condition to use is_debug_enabled() method 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com> * Move plotting logic from modeling_smolvla to eval_dataset script Refactor to improve separation of concerns: modeling_smolvla.py changes: - Remove all plotting logic from sample_actions method - Remove viz_xt_axs, viz_vt_axs, viz_x1t_axs parameters - Remove matplotlib and RTCDebugVisualizer imports - Remove viz_fig, viz_axs, denoise_step_counter instance variables - Simplify denoising loop to only track data in rtc_processor eval_dataset.py changes: - Add _plot_denoising_steps_from_tracker helper method - Retrieve debug steps from tracker after inference - Plot x_t, v_t, x1_t, correction, and error from tracker data - Enable debug tracking (cfg.rtc.debug = True) for visualization - Remove viz axes parameters from predict_action_chunk calls modeling_rtc.py changes: - Remove v_t from track() call (handled by user change) Benefits: - Cleaner modeling code focused on inference - Evaluation script owns all visualization logic - Better separation of concerns - Tracker is single source of truth for debug data 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com> * Refactor plotting loging * fixup! 
Refactor plotting loging * Improve visualization: separate correction plot and fix axis scaling Changes: - Create separate figure for correction data instead of overlaying on v_t - Add _rescale_axes helper method to properly scale all axes - Add 10% margin to y-axis for better visualization - Fix v_t chart vertical compression issue Benefits: - Clearer v_t plot without correction overlay - Better axis scaling with proper margins - Separate correction figure for focused analysis - Improved readability of all denoising visualizations Output files: - denoising_xt_comparison.png (x_t trajectories) - denoising_vt_comparison.png (v_t velocity - now cleaner) - denoising_correction_comparison.png (NEW - separate corrections) - denoising_x1t_comparison.png (x1_t state with error) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com> * fixup! Improve visualization: separate correction plot and fix axis scaling * fixup! fixup! Improve visualization: separate correction plot and fix axis scaling * fixup! fixup! fixup! Improve visualization: separate correction plot and fix axis scaling * Fix traacking * Right kwargs for the policy * Add tests for tracker * Fix tests * Drop not required methods * Add torch compilation for eval_dataset * delete policies * Add matplotliv to dev * fixup! Add matplotliv to dev * Experiemnt with late detach * Debug * Fix compilation * Add RTC to PI0 * Pi0 * Pi0 eval dataset * fixup! Pi0 eval dataset * Turn off compilation for pi0/pi05 * fixup! Turn off compilation for pi0/pi05 * fixup! fixup! Turn off compilation for pi0/pi05 * fixup! fixup! fixup! Turn off compilation for pi0/pi05 * fixup! fixup! fixup! fixup! Turn off compilation for pi0/pi05 * fixup! fixup! fixup! fixup! fixup! 
Turn off compilation for pi0/pi05 * Add workable flow * Small fixes * Add more tests * Add validatio at the end * Update README * Silent validation * Fix tests * Add tests for modeling_rtc * Add tests for flow matching models with RTC * fixup! Add tests for flow matching models with RTC * fixup! fixup! Add tests for flow matching models with RTC * Add one more test * fixup! Add one more test * Fix test to use _rtc_enabled() instead of is_rtc_enabled() 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * fixup! Fix test to use _rtc_enabled() instead of is_rtc_enabled() * fixup! fixup! Fix test to use _rtc_enabled() instead of is_rtc_enabled() * Add RTC initialization tests without config for PI0.5 and SmolVLA Add test_pi05_rtc_initialization_without_rtc_config and test_smolvla_rtc_initialization_without_rtc_config to verify that policies can initialize without RTC config and that _rtc_enabled() returns False in this case. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * Fix PI0.5 init_rtc_processor to use getattr instead of direct model access 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * Fix SmolVLA init_rtc_processor to use getattr instead of direct model access 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * Fix PI0.5 RTC tests to use quantile stats (q01, q99) for normalization 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * fixup! Fix PI0.5 RTC tests to use quantile stats (q01, q99) for normalization * Fixup eval with real robot * fixup! Fixup eval with real robot * fixup! fixup! Fixup eval with real robot * Extract simulator logic from eval_with real robot and add proper headers to files * Update images * Fix tests * fixup! 
Fix tests * add docs for rtc * enhance doc and add images * Fix instal instructions --------- Co-authored-by: Ben Zhang <benzhangniu@gmail.com> Co-authored-by: Alexander Soare <alexander.soare159@gmail.com> Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
323 lines
8.1 KiB
Python
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for RTC LatencyTracker module."""

import pytest

from lerobot.policies.rtc.latency_tracker import LatencyTracker

# ====================== Fixtures ======================


@pytest.fixture
def tracker():
    """Create a LatencyTracker with default maxlen."""
    return LatencyTracker(maxlen=100)


@pytest.fixture
def small_tracker():
    """Create a LatencyTracker with small maxlen for overflow testing."""
    return LatencyTracker(maxlen=5)


# ====================== Initialization Tests ======================


def test_latency_tracker_initialization():
    """Test LatencyTracker starts empty with a max latency of 0.0."""
    tracker = LatencyTracker(maxlen=50)
    assert len(tracker) == 0
    assert tracker.max_latency == 0.0
    assert tracker.max() == 0.0


def test_latency_tracker_default_maxlen():
    """Test LatencyTracker can be constructed without an explicit maxlen."""
    tracker = LatencyTracker()
    # Omitting maxlen should fall back to the default (maxlen=100)
    assert len(tracker) == 0


# ====================== add() Tests ======================


def test_add_single_latency(tracker):
    """Test adding a single latency value."""
    tracker.add(0.5)
    assert len(tracker) == 1
    assert tracker.max() == 0.5


def test_add_multiple_latencies(tracker):
    """Test adding multiple latency values."""
    latencies = [0.1, 0.5, 0.3, 0.8, 0.2]
    for lat in latencies:
        tracker.add(lat)

    assert len(tracker) == 5
    assert tracker.max() == 0.8


def test_add_negative_latency_ignored(tracker):
    """Test that negative latencies are ignored."""
    tracker.add(0.5)
    tracker.add(-0.1)
    tracker.add(0.3)

    # Should only have 2 valid latencies
    assert len(tracker) == 2
    assert tracker.max() == 0.5


def test_add_zero_latency(tracker):
    """Test adding zero latency."""
    tracker.add(0.0)
    assert len(tracker) == 1
    assert tracker.max() == 0.0


def test_add_converts_to_float(tracker):
    """Test add() converts input to float."""
    tracker.add(5)  # Integer
    tracker.add("3.5")  # String

    assert len(tracker) == 2
    assert tracker.max() == 5.0


def test_add_updates_max_latency(tracker):
    """Test that max_latency is updated correctly."""
    tracker.add(0.5)
    assert tracker.max_latency == 0.5

    tracker.add(0.3)
    assert tracker.max_latency == 0.5  # Should not decrease

    tracker.add(0.9)
    assert tracker.max_latency == 0.9  # Should increase


# ====================== reset() Tests ======================


def test_reset_clears_values(tracker):
    """Test reset() clears all values."""
    tracker.add(0.5)
    tracker.add(0.8)
    tracker.add(0.3)
    assert len(tracker) == 3

    tracker.reset()
    assert len(tracker) == 0
    assert tracker.max_latency == 0.0


def test_reset_clears_max_latency(tracker):
    """Test reset() resets max_latency."""
    tracker.add(1.5)
    assert tracker.max_latency == 1.5

    tracker.reset()
    assert tracker.max_latency == 0.0


def test_reset_allows_new_values(tracker):
    """Test that tracker works correctly after reset."""
    tracker.add(0.5)
    tracker.reset()

    tracker.add(0.3)
    assert len(tracker) == 1
    assert tracker.max() == 0.3


# ====================== max() Tests ======================


def test_max_returns_zero_when_empty(tracker):
    """Test max() returns 0.0 when tracker is empty."""
    assert tracker.max() == 0.0


def test_max_returns_maximum_value(tracker):
    """Test max() returns the maximum latency."""
    latencies = [0.2, 0.8, 0.3, 0.5, 0.1]
    for lat in latencies:
        tracker.add(lat)

    assert tracker.max() == 0.8


def test_max_persists_after_sliding_window(small_tracker):
    """Test max() is unaffected when older values slide out of the window."""
    # Add six values to a tracker with maxlen=5
    small_tracker.add(0.1)
    small_tracker.add(0.9)  # This is max
    small_tracker.add(0.2)
    small_tracker.add(0.3)
    small_tracker.add(0.4)
    small_tracker.add(0.5)  # This evicts the oldest value (0.1)

    # Only the last 5 values are kept (0.9 itself is still inside the window),
    # and max() still reports 0.9 after the eviction
    assert len(small_tracker) == 5
    assert small_tracker.max() == 0.9


def test_max_after_reset(tracker):
    """Test max() returns 0.0 after reset."""
    tracker.add(1.5)
    tracker.reset()
    assert tracker.max() == 0.0


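# The max() tests above treat max() as a running maximum that is independent
# of the sliding window. Below is a minimal sketch of that design. It assumes a
# list-backed window of `maxlen` values plus a separately stored maximum.
# _SketchLatencyTracker is hypothetical. It illustrates the semantics these
# tests describe and is not lerobot's LatencyTracker implementation.
class _SketchLatencyTracker:
    def __init__(self, maxlen=100):
        self.maxlen = maxlen
        self.reset()

    def reset(self):
        # Clear the window and the running maximum together.
        self._values = []
        self.max_latency = 0.0

    def add(self, latency):
        value = float(latency)  # accepts ints and numeric strings
        if value < 0:
            return  # negative latencies are dropped
        self._values.append(value)
        if len(self._values) > self.maxlen:
            self._values.pop(0)  # evict the oldest value
        # The maximum is never recomputed from the window, so it survives
        # eviction and only goes back to 0.0 on reset().
        self.max_latency = max(self.max_latency, value)

    def __len__(self):
        return len(self._values)

    def max(self):
        return self.max_latency


def _sketch_running_max_demo():
    # Hypothetical usage: with maxlen=2, adding 0.2 evicts 0.9, yet max()
    # still reports 0.9. Returns (2, 0.9).
    sketch = _SketchLatencyTracker(maxlen=2)
    for lat in (0.9, 0.1, 0.2):
        sketch.add(lat)
    return len(sketch), sketch.max()

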
# ====================== p95() Tests ======================


def test_p95_returns_zero_when_empty(tracker):
    """Test p95() returns 0.0 when tracker is empty."""
    assert tracker.p95() == 0.0


def test_p95_returns_95th_percentile(tracker):
    """Test p95() returns the 95th percentile."""
    # Add 100 evenly spaced values: 0.00, 0.01, ..., 0.99
    for i in range(100):
        tracker.add(i / 100.0)

    p95 = tracker.p95()
    assert 0.93 <= p95 <= 0.96


def test_p95_equals_percentile_95(tracker):
    """Test p95() equals percentile(0.95)."""
    for i in range(50):
        tracker.add(i / 50.0)

    assert tracker.p95() == tracker.percentile(0.95)


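# The p95 tests bound the result only loosely because the exact value depends
# on the interpolation rule. Below is a minimal sketch that assumes linear
# interpolation between sorted values, which is the rule NumPy's np.quantile
# uses by default. Whether LatencyTracker.percentile() uses this rule is an
# assumption. _sketch_percentile is a hypothetical, pure-Python helper.
def _sketch_percentile(values, q):
    if not values:
        return 0.0  # mirrors the empty-tracker tests
    ordered = sorted(float(v) for v in values)
    q = min(max(float(q), 0.0), 1.0)  # clamp q to [0, 1]
    pos = (len(ordered) - 1) * q  # fractional index into the sorted values
    lo = int(pos)
    hi = min(lo + 1, len(ordered) - 1)
    return ordered[lo] + (ordered[hi] - ordered[lo]) * (pos - lo)


def _sketch_p95_demo():
    # Hypothetical usage: for 0.00, 0.01, ..., 0.99 the position is
    # 99 * 0.95 = 94.05, so p95 interpolates between 0.94 and 0.95. The
    # result is about 0.9405, inside the 0.93-0.96 window asserted above.
    return _sketch_percentile([i / 100.0 for i in range(100)], 0.95)

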
# ====================== Edge Cases Tests ======================


def test_single_value(tracker):
    """Test tracker behavior with single value."""
    tracker.add(0.75)

    assert len(tracker) == 1
    assert tracker.max() == 0.75
    assert tracker.percentile(0.0) == 0.75
    assert tracker.percentile(0.5) == 0.75
    assert tracker.percentile(1.0) == 0.75


def test_all_same_values(tracker):
    """Test tracker with all identical values."""
    for _ in range(10):
        tracker.add(0.5)

    assert len(tracker) == 10
    assert tracker.max() == 0.5
    assert tracker.percentile(0.0) == 0.5
    assert tracker.percentile(0.5) == 0.5
    assert tracker.percentile(1.0) == 0.5


def test_very_small_values(tracker):
    """Test tracker with very small float values."""
    tracker.add(1e-10)
    tracker.add(2e-10)
    tracker.add(3e-10)

    assert len(tracker) == 3
    assert tracker.max() == pytest.approx(3e-10)


def test_very_large_values(tracker):
    """Test tracker with very large float values."""
    tracker.add(1e10)
    tracker.add(2e10)
    tracker.add(3e10)

    assert len(tracker) == 3
    assert tracker.max() == pytest.approx(3e10)


# ====================== Integration Tests ======================


def test_typical_usage_pattern(tracker):
    """Test a typical usage pattern of the tracker."""
    # Simulate adding latencies over time
    latencies = [0.05, 0.08, 0.12, 0.07, 0.15, 0.09, 0.11, 0.06, 0.14, 0.10]

    for lat in latencies:
        tracker.add(lat)

    # Check statistics
    assert len(tracker) == 10
    assert tracker.max() == 0.15

    # p95 should be close to max since we have only 10 values
    p95 = tracker.p95()
    assert p95 >= tracker.percentile(0.5)  # p95 should be >= median
    assert p95 <= tracker.max()  # p95 should be <= max


def test_reset_and_reuse(tracker):
    """Test resetting and reusing tracker."""
    # First batch
    tracker.add(1.0)
    tracker.add(2.0)
    assert tracker.max() == 2.0

    # Reset
    tracker.reset()

    # Second batch
    tracker.add(0.5)
    tracker.add(0.8)
    assert len(tracker) == 2
    assert tracker.max() == 0.8
    assert tracker.percentile(0.5) <= 0.8


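# The values fed to a latency tracker are elapsed durations. Below is a
# minimal sketch that assumes each call is timed with time.perf_counter() and
# that the elapsed seconds are passed to add(). _timed_call is a hypothetical
# helper. `recorder` may be any object with an add(float) method, such as
# LatencyTracker.
def _timed_call(recorder, fn, *args, **kwargs):
    import time  # local import keeps this sketch self-contained

    start = time.perf_counter()
    result = fn(*args, **kwargs)
    recorder.add(time.perf_counter() - start)  # elapsed seconds for this call
    return result

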
# ====================== Type Conversion Tests ======================


def test_add_with_integer(tracker):
    """Test adding integer values."""
    tracker.add(5)
    assert len(tracker) == 1
    assert tracker.max() == 5.0


def test_add_with_string_number(tracker):
    """Test adding string representation of number."""
    tracker.add("3.14")
    assert len(tracker) == 1
    assert tracker.max() == pytest.approx(3.14)


def test_percentile_converts_q_to_float(tracker):
    """Test percentile converts q parameter to float."""
    tracker.add(0.5)
    tracker.add(0.8)

    # Pass integer q
    result = tracker.percentile(1)
    assert result == 0.8