From 99eea2ae034b09ea3315310ced8fac59646142eb Mon Sep 17 00:00:00 2001 From: Eugene Mironov Date: Sat, 8 Nov 2025 17:07:45 +0700 Subject: [PATCH] Add more tests --- src/lerobot/policies/rtc/modeling_rtc.py | 7 +- tests/policies/rtc/test_action_queue.py | 825 +++++++++++++++++++ tests/policies/rtc/test_configuration_rtc.py | 323 ++++++++ tests/policies/rtc/test_debug_visualizer.py | 427 ++++++++++ tests/policies/rtc/test_latency_tracker.py | 481 +++++++++++ tests/policies/rtc/test_modeling_rtc.py | 784 ++++++++++++++++++ 6 files changed, 2844 insertions(+), 3 deletions(-) create mode 100644 tests/policies/rtc/test_action_queue.py create mode 100644 tests/policies/rtc/test_configuration_rtc.py create mode 100644 tests/policies/rtc/test_debug_visualizer.py create mode 100644 tests/policies/rtc/test_latency_tracker.py create mode 100644 tests/policies/rtc/test_modeling_rtc.py diff --git a/src/lerobot/policies/rtc/modeling_rtc.py b/src/lerobot/policies/rtc/modeling_rtc.py index 1994c76ac..0445aa982 100644 --- a/src/lerobot/policies/rtc/modeling_rtc.py +++ b/src/lerobot/policies/rtc/modeling_rtc.py @@ -217,9 +217,10 @@ class RTCProcessor: correction = torch.autograd.grad(x1_t, x_t, grad_outputs, retain_graph=False)[0] max_guidance_weight = torch.as_tensor(self.rtc_config.max_guidance_weight) - squared_one_minus_tau = (1 - tau) ** 2 - inv_r2 = (squared_one_minus_tau + tau**2) / (squared_one_minus_tau) - c = torch.nan_to_num((1 - tau) / tau, posinf=max_guidance_weight) + tau_tensor = torch.as_tensor(tau) + squared_one_minus_tau = (1 - tau_tensor) ** 2 + inv_r2 = (squared_one_minus_tau + tau_tensor**2) / (squared_one_minus_tau) + c = torch.nan_to_num((1 - tau_tensor) / tau_tensor, posinf=max_guidance_weight) guidance_weight = torch.nan_to_num(c * inv_r2, posinf=max_guidance_weight) guidance_weight = torch.minimum(guidance_weight, max_guidance_weight) diff --git a/tests/policies/rtc/test_action_queue.py b/tests/policies/rtc/test_action_queue.py new file mode 100644 index 
000000000..2f9b84384 --- /dev/null +++ b/tests/policies/rtc/test_action_queue.py @@ -0,0 +1,825 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for RTC ActionQueue module.""" + +import threading +import time + +import pytest +import torch + +from lerobot.policies.rtc.action_queue import ActionQueue +from lerobot.policies.rtc.configuration_rtc import RTCConfig + +# ====================== Fixtures ====================== + + +@pytest.fixture +def rtc_config_enabled(): + """Create an RTC config with RTC enabled.""" + return RTCConfig(enabled=True, execution_horizon=10, max_guidance_weight=1.0) + + +@pytest.fixture +def rtc_config_disabled(): + """Create an RTC config with RTC disabled.""" + return RTCConfig(enabled=False, execution_horizon=10, max_guidance_weight=1.0) + + +@pytest.fixture +def sample_actions(): + """Create sample action tensors for testing.""" + return { + "original": torch.randn(50, 6), # (time_steps, action_dim) + "processed": torch.randn(50, 6), + "short": torch.randn(10, 6), + "longer": torch.randn(100, 6), + } + + +@pytest.fixture +def action_queue_rtc_enabled(rtc_config_enabled): + """Create an ActionQueue with RTC enabled.""" + return ActionQueue(rtc_config_enabled) + + +@pytest.fixture +def action_queue_rtc_disabled(rtc_config_disabled): + """Create an ActionQueue with RTC disabled.""" + return ActionQueue(rtc_config_disabled) + + +# 
====================== Initialization Tests ====================== + + +def test_action_queue_initialization_rtc_enabled(rtc_config_enabled): + """Test ActionQueue initializes correctly with RTC enabled.""" + queue = ActionQueue(rtc_config_enabled) + assert queue.queue is None + assert queue.original_queue is None + assert queue.last_index == 0 + assert queue.cfg.enabled is True + + +def test_action_queue_initialization_rtc_disabled(rtc_config_disabled): + """Test ActionQueue initializes correctly with RTC disabled.""" + queue = ActionQueue(rtc_config_disabled) + assert queue.queue is None + assert queue.original_queue is None + assert queue.last_index == 0 + assert queue.cfg.enabled is False + + +# ====================== get() Tests ====================== + + +def test_get_returns_none_when_empty(action_queue_rtc_enabled): + """Test get() returns None when queue is empty.""" + action = action_queue_rtc_enabled.get() + assert action is None + + +def test_get_returns_actions_sequentially(action_queue_rtc_enabled, sample_actions): + """Test get() returns actions in sequence.""" + # Initialize queue with actions + action_queue_rtc_enabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=0) + + # Get first action + action1 = action_queue_rtc_enabled.get() + assert action1 is not None + assert action1.shape == (6,) + assert torch.equal(action1, sample_actions["processed"][0]) + + # Get second action + action2 = action_queue_rtc_enabled.get() + assert action2 is not None + assert torch.equal(action2, sample_actions["processed"][1]) + + +def test_get_returns_none_after_exhaustion(action_queue_rtc_enabled, sample_actions): + """Test get() returns None after all actions are consumed.""" + # Use short action sequence + action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0) + + # Consume all actions + for _ in range(10): + action = action_queue_rtc_enabled.get() + assert action is not None + + # Next get should 
return None + action = action_queue_rtc_enabled.get() + assert action is None + + +def test_get_increments_last_index(action_queue_rtc_enabled, sample_actions): + """Test get() increments last_index correctly.""" + action_queue_rtc_enabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=0) + + assert action_queue_rtc_enabled.last_index == 0 + action_queue_rtc_enabled.get() + assert action_queue_rtc_enabled.last_index == 1 + action_queue_rtc_enabled.get() + assert action_queue_rtc_enabled.last_index == 2 + + +# ====================== qsize() Tests ====================== + + +def test_qsize_returns_zero_when_empty(action_queue_rtc_enabled): + """Test qsize() returns 0 when queue is empty.""" + assert action_queue_rtc_enabled.qsize() == 0 + + +def test_qsize_returns_correct_size(action_queue_rtc_enabled, sample_actions): + """Test qsize() returns correct number of remaining actions.""" + action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0) + assert action_queue_rtc_enabled.qsize() == 10 + + action_queue_rtc_enabled.get() + assert action_queue_rtc_enabled.qsize() == 9 + + action_queue_rtc_enabled.get() + assert action_queue_rtc_enabled.qsize() == 8 + + +def test_qsize_after_exhaustion(action_queue_rtc_enabled, sample_actions): + """Test qsize() returns 0 after queue is exhausted.""" + action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0) + + # Consume all actions + for _ in range(10): + action_queue_rtc_enabled.get() + + assert action_queue_rtc_enabled.qsize() == 0 + + +# ====================== empty() Tests ====================== + + +def test_empty_returns_true_when_empty(action_queue_rtc_enabled): + """Test empty() returns True when queue is empty.""" + assert action_queue_rtc_enabled.empty() is True + + +def test_empty_returns_false_when_not_empty(action_queue_rtc_enabled, sample_actions): + """Test empty() returns False when queue has actions.""" + 
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0) + assert action_queue_rtc_enabled.empty() is False + + +def test_empty_after_partial_consumption(action_queue_rtc_enabled, sample_actions): + """Test empty() returns False after partial consumption.""" + action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0) + + action_queue_rtc_enabled.get() + action_queue_rtc_enabled.get() + + assert action_queue_rtc_enabled.empty() is False + + +def test_empty_after_full_consumption(action_queue_rtc_enabled, sample_actions): + """Test empty() returns True after all actions consumed.""" + action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0) + + # Consume all + for _ in range(10): + action_queue_rtc_enabled.get() + + assert action_queue_rtc_enabled.empty() is True + + +# ====================== get_action_index() Tests ====================== + + +def test_get_action_index_initial_value(action_queue_rtc_enabled): + """Test get_action_index() returns 0 initially.""" + assert action_queue_rtc_enabled.get_action_index() == 0 + + +def test_get_action_index_after_consumption(action_queue_rtc_enabled, sample_actions): + """Test get_action_index() tracks consumption correctly.""" + action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0) + + assert action_queue_rtc_enabled.get_action_index() == 0 + action_queue_rtc_enabled.get() + assert action_queue_rtc_enabled.get_action_index() == 1 + action_queue_rtc_enabled.get() + action_queue_rtc_enabled.get() + assert action_queue_rtc_enabled.get_action_index() == 3 + + +# ====================== get_left_over() Tests ====================== + + +def test_get_left_over_returns_none_when_empty(action_queue_rtc_enabled): + """Test get_left_over() returns None when queue is empty.""" + leftover = action_queue_rtc_enabled.get_left_over() + assert leftover is None + + +def 
test_get_left_over_returns_all_when_unconsumed(action_queue_rtc_enabled, sample_actions): + """Test get_left_over() returns all original actions when none consumed.""" + action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0) + + leftover = action_queue_rtc_enabled.get_left_over() + assert leftover is not None + assert leftover.shape == (10, 6) + assert torch.equal(leftover, sample_actions["short"]) + + +def test_get_left_over_returns_remaining_after_consumption(action_queue_rtc_enabled, sample_actions): + """Test get_left_over() returns only remaining original actions.""" + action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0) + + # Consume 3 actions + action_queue_rtc_enabled.get() + action_queue_rtc_enabled.get() + action_queue_rtc_enabled.get() + + leftover = action_queue_rtc_enabled.get_left_over() + assert leftover is not None + assert leftover.shape == (7, 6) + assert torch.equal(leftover, sample_actions["short"][3:]) + + +def test_get_left_over_returns_empty_after_exhaustion(action_queue_rtc_enabled, sample_actions): + """Test get_left_over() returns empty tensor after all consumed.""" + action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0) + + # Consume all + for _ in range(10): + action_queue_rtc_enabled.get() + + leftover = action_queue_rtc_enabled.get_left_over() + assert leftover is not None + assert leftover.shape == (0, 6) + + +# ====================== merge() with RTC Enabled Tests ====================== + + +def test_merge_replaces_queue_when_rtc_enabled(action_queue_rtc_enabled, sample_actions): + """Test merge() replaces queue when RTC is enabled.""" + # Add initial actions + action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0) + assert action_queue_rtc_enabled.qsize() == 10 + + # Consume some actions + action_queue_rtc_enabled.get() + action_queue_rtc_enabled.get() + assert 
action_queue_rtc_enabled.qsize() == 8 + + # Merge new actions - should replace, not append + action_queue_rtc_enabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=5) + + # Queue should be replaced with new actions minus delay + # Original has 50 actions, delay is 5, so remaining is 45 + assert action_queue_rtc_enabled.qsize() == 45 + assert action_queue_rtc_enabled.get_action_index() == 0 + + +def test_merge_respects_real_delay(action_queue_rtc_enabled, sample_actions): + """Test merge() correctly applies real_delay when RTC is enabled.""" + delay = 10 + action_queue_rtc_enabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=delay) + + # Queue should have original length minus delay + expected_size = len(sample_actions["original"]) - delay + assert action_queue_rtc_enabled.qsize() == expected_size + + # First action should be the one at index [delay] + first_action = action_queue_rtc_enabled.get() + assert torch.equal(first_action, sample_actions["processed"][delay]) + + +def test_merge_resets_last_index_when_rtc_enabled(action_queue_rtc_enabled, sample_actions): + """Test merge() resets last_index to 0 when RTC is enabled.""" + action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0) + action_queue_rtc_enabled.get() + action_queue_rtc_enabled.get() + assert action_queue_rtc_enabled.last_index == 2 + + # Merge new actions + action_queue_rtc_enabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=5) + + assert action_queue_rtc_enabled.last_index == 0 + + +def test_merge_with_zero_delay(action_queue_rtc_enabled, sample_actions): + """Test merge() with zero delay keeps all actions.""" + action_queue_rtc_enabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=0) + + assert action_queue_rtc_enabled.qsize() == len(sample_actions["original"]) + + +def test_merge_with_large_delay(action_queue_rtc_enabled, sample_actions): + 
"""Test merge() with delay larger than action sequence.""" + # Delay is larger than sequence length + delay = 100 + action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=delay) + + # Queue should be empty (delay >= length) + assert action_queue_rtc_enabled.qsize() == 0 + + +# ====================== merge() with RTC Disabled Tests ====================== + + +def test_merge_appends_when_rtc_disabled(action_queue_rtc_disabled, sample_actions): + """Test merge() appends actions when RTC is disabled.""" + # Add initial actions + action_queue_rtc_disabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0) + initial_size = action_queue_rtc_disabled.qsize() + assert initial_size == 10 + + # Merge more actions + action_queue_rtc_disabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0) + + # Should have appended + assert action_queue_rtc_disabled.qsize() == initial_size + 10 + + +def test_merge_removes_consumed_actions_when_appending(action_queue_rtc_disabled, sample_actions): + """Test merge() removes consumed actions before appending when RTC is disabled.""" + # Add initial actions + action_queue_rtc_disabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0) + assert action_queue_rtc_disabled.qsize() == 10 + + # Consume 3 actions + action_queue_rtc_disabled.get() + action_queue_rtc_disabled.get() + action_queue_rtc_disabled.get() + assert action_queue_rtc_disabled.qsize() == 7 + + # Merge more actions + action_queue_rtc_disabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0) + + # Should have 7 remaining + 10 new = 17 + assert action_queue_rtc_disabled.qsize() == 17 + + +def test_merge_resets_last_index_after_append(action_queue_rtc_disabled, sample_actions): + """Test merge() resets last_index after appending when RTC is disabled.""" + action_queue_rtc_disabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0) + 
action_queue_rtc_disabled.get() + action_queue_rtc_disabled.get() + assert action_queue_rtc_disabled.last_index == 2 + + # Merge more actions + action_queue_rtc_disabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0) + + # last_index should be reset to 0 + assert action_queue_rtc_disabled.last_index == 0 + + +def test_merge_ignores_delay_when_rtc_disabled(action_queue_rtc_disabled, sample_actions): + """Test merge() ignores real_delay parameter when RTC is disabled.""" + action_queue_rtc_disabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=10) + + # All actions should be in queue (delay ignored) + assert action_queue_rtc_disabled.qsize() == len(sample_actions["original"]) + + +def test_merge_first_call_with_rtc_disabled(action_queue_rtc_disabled, sample_actions): + """Test merge() on first call with RTC disabled.""" + action_queue_rtc_disabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=0) + + assert action_queue_rtc_disabled.qsize() == len(sample_actions["original"]) + assert action_queue_rtc_disabled.last_index == 0 + + +# ====================== merge() with Different Action Shapes Tests ====================== + + +def test_merge_with_different_action_dims(): + """Test merge() handles actions with different dimensions.""" + cfg = RTCConfig(enabled=True, execution_horizon=10) + queue = ActionQueue(cfg) + + # Actions with 4 dimensions instead of 6 + actions_4d = torch.randn(20, 4) + queue.merge(actions_4d, actions_4d, real_delay=5) + + action = queue.get() + assert action.shape == (4,) + + +def test_merge_with_different_lengths(): + """Test merge() handles action sequences of varying lengths.""" + cfg = RTCConfig(enabled=False, execution_horizon=10) + queue = ActionQueue(cfg) + + # Add sequences of different lengths + queue.merge(torch.randn(10, 6), torch.randn(10, 6), real_delay=0) + assert queue.qsize() == 10 + + queue.merge(torch.randn(25, 6), torch.randn(25, 6), real_delay=0) + 
assert queue.qsize() == 35 + + +# ====================== merge() Delay Validation Tests ====================== + + +def test_merge_validates_delay_consistency(action_queue_rtc_enabled, sample_actions, caplog): + """Test merge() validates that real_delay matches action index difference.""" + import logging + + caplog.set_level(logging.WARNING) + + # Initialize queue + action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0) + + # Consume 5 actions + for _ in range(5): + action_queue_rtc_enabled.get() + + # Merge with mismatched delay (should log warning) + # We consumed 5 actions, so index is 5. If we pass action_index_before_inference=0, + # then indexes_diff=5, but if real_delay=3, it will warn + action_queue_rtc_enabled.merge( + sample_actions["original"], + sample_actions["processed"], + real_delay=3, + action_index_before_inference=0, + ) + + # Check warning was logged + assert "Indexes diff is not equal to real delay" in caplog.text + + +def test_merge_no_warning_when_delays_match(action_queue_rtc_enabled, sample_actions, caplog): + """Test merge() doesn't warn when delays are consistent.""" + import logging + + caplog.set_level(logging.WARNING) + + # Initialize queue + action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0) + + # Consume 5 actions + for _ in range(5): + action_queue_rtc_enabled.get() + + # Merge with matching delay + action_queue_rtc_enabled.merge( + sample_actions["original"], + sample_actions["processed"], + real_delay=5, + action_index_before_inference=0, + ) + + # Should not have warning + assert "Indexes diff is not equal to real delay" not in caplog.text + + +def test_merge_skips_validation_when_action_index_none(action_queue_rtc_enabled, sample_actions, caplog): + """Test merge() skips delay validation when action_index_before_inference is None.""" + import logging + + caplog.set_level(logging.WARNING) + + 
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0) + + for _ in range(5): + action_queue_rtc_enabled.get() + + # Pass None for action_index_before_inference + action_queue_rtc_enabled.merge( + sample_actions["original"], + sample_actions["processed"], + real_delay=999, # Doesn't matter + action_index_before_inference=None, + ) + + # Should not warn (validation skipped) + assert "Indexes diff is not equal to real delay" not in caplog.text + + +# ====================== Thread Safety Tests ====================== + + +def test_get_is_thread_safe(action_queue_rtc_enabled, sample_actions): + """Test get() is thread-safe with multiple consumers.""" + action_queue_rtc_enabled.merge(sample_actions["longer"], sample_actions["longer"], real_delay=0) + + results = [] + errors = [] + + def consumer(): + try: + for _ in range(25): + action = action_queue_rtc_enabled.get() + if action is not None: + results.append(action) + time.sleep(0.001) + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=consumer) for _ in range(4)] + + for t in threads: + t.start() + + for t in threads: + t.join() + + # Should not have errors + assert len(errors) == 0 + + # Should have consumed all actions (100 total, 4 threads * 25 each) + assert len(results) == 100 + + # All results should be unique (no duplicate consumption) + # We can verify by checking that indices are not duplicated + # Since we don't track indices in results, we check total count is correct + assert action_queue_rtc_enabled.qsize() == 0 + + +def test_merge_is_thread_safe(action_queue_rtc_disabled, sample_actions): + """Test merge() is thread-safe with multiple producers.""" + errors = [] + + def producer(): + try: + for _ in range(5): + action_queue_rtc_disabled.merge( + sample_actions["short"], sample_actions["short"], real_delay=0 + ) + time.sleep(0.001) + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=producer) for _ in 
range(3)] + + for t in threads: + t.start() + + for t in threads: + t.join() + + # Should not have errors + assert len(errors) == 0 + + # Should have accumulated all actions (3 threads * 5 merges * 10 actions = 150) + assert action_queue_rtc_disabled.qsize() == 150 + + +def test_concurrent_get_and_merge(action_queue_rtc_disabled, sample_actions): + """Test concurrent get() and merge() operations.""" + errors = [] + consumed_count = [0] + + def consumer(): + try: + for _ in range(50): + action = action_queue_rtc_disabled.get() + if action is not None: + consumed_count[0] += 1 + time.sleep(0.001) + except Exception as e: + errors.append(e) + + def producer(): + try: + for _ in range(10): + action_queue_rtc_disabled.merge( + sample_actions["short"], sample_actions["short"], real_delay=0 + ) + time.sleep(0.005) + except Exception as e: + errors.append(e) + + consumer_threads = [threading.Thread(target=consumer) for _ in range(2)] + producer_threads = [threading.Thread(target=producer) for _ in range(2)] + + for t in consumer_threads + producer_threads: + t.start() + + for t in consumer_threads + producer_threads: + t.join() + + # Should not have errors + assert len(errors) == 0 + + # Should have consumed some or all actions (non-deterministic due to timing) + # Total produced: 2 producers * 10 merges * 10 actions = 200 + # Total consumed attempts: 2 consumers * 50 = 100 + assert consumed_count[0] <= 200 + + +# ====================== get_left_over() Thread Safety Tests ====================== + + +def test_get_left_over_is_thread_safe(action_queue_rtc_enabled, sample_actions): + """Test get_left_over() is thread-safe with concurrent access.""" + action_queue_rtc_enabled.merge(sample_actions["longer"], sample_actions["longer"], real_delay=0) + + errors = [] + leftovers = [] + + def reader(): + try: + for _ in range(20): + leftover = action_queue_rtc_enabled.get_left_over() + if leftover is not None: + leftovers.append(leftover.shape[0]) + time.sleep(0.001) + except 
Exception as e: + errors.append(e) + + threads = [threading.Thread(target=reader) for _ in range(3)] + + # Also consume some actions concurrently + def consumer(): + try: + for _ in range(10): + action_queue_rtc_enabled.get() + time.sleep(0.002) + except Exception as e: + errors.append(e) + + consumer_thread = threading.Thread(target=consumer) + + all_threads = threads + [consumer_thread] + + for t in all_threads: + t.start() + + for t in all_threads: + t.join() + + # Should not have errors + assert len(errors) == 0 + + # Leftovers should be monotonically decreasing or stable + # (as actions are consumed, leftover size decreases) + assert len(leftovers) > 0 + + +# ====================== Edge Cases Tests ====================== + + +def test_queue_with_single_action(action_queue_rtc_enabled): + """Test queue behavior with a single action.""" + single_action_original = torch.randn(1, 6) + single_action_processed = torch.randn(1, 6) + + action_queue_rtc_enabled.merge(single_action_original, single_action_processed, real_delay=0) + + assert action_queue_rtc_enabled.qsize() == 1 + action = action_queue_rtc_enabled.get() + assert action is not None + assert action.shape == (6,) + assert action_queue_rtc_enabled.qsize() == 0 + + +def test_queue_behavior_after_multiple_merge_cycles(action_queue_rtc_enabled, sample_actions): + """Test queue maintains correct state through multiple merge cycles.""" + for _ in range(5): + action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0) + + # Consume half + for _ in range(5): + action_queue_rtc_enabled.get() + + # Merge again + action_queue_rtc_enabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=3) + + assert action_queue_rtc_enabled.qsize() > 0 + + +def test_queue_with_all_zeros_actions(action_queue_rtc_enabled): + """Test queue handles all-zero action tensors.""" + zeros_actions = torch.zeros(20, 6) + action_queue_rtc_enabled.merge(zeros_actions, zeros_actions, 
real_delay=0) + + action = action_queue_rtc_enabled.get() + assert torch.all(action == 0) + + +def test_queue_clones_input_tensors(action_queue_rtc_enabled, sample_actions): + """Test that merge() clones input tensors, not storing references.""" + original_copy = sample_actions["original"].clone() + processed_copy = sample_actions["processed"].clone() + + action_queue_rtc_enabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=0) + + # Modify original tensors + sample_actions["original"].fill_(999.0) + sample_actions["processed"].fill_(-999.0) + + # Queue should have cloned values + action = action_queue_rtc_enabled.get() + assert not torch.equal(action, sample_actions["processed"][0]) + assert torch.equal(action, processed_copy[0]) + + leftover = action_queue_rtc_enabled.get_left_over() + assert not torch.equal(leftover, sample_actions["original"][1:]) + assert torch.equal(leftover, original_copy[1:]) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_queue_handles_gpu_tensors(): + """Test queue correctly handles GPU tensors.""" + cfg = RTCConfig(enabled=True, execution_horizon=10) + queue = ActionQueue(cfg) + + actions_gpu = torch.randn(20, 6, device="cuda") + queue.merge(actions_gpu, actions_gpu, real_delay=0) + + action = queue.get() + assert action.device.type == "cuda" + + leftover = queue.get_left_over() + assert leftover.device.type == "cuda" + + +def test_queue_handles_different_dtypes(): + """Test queue handles actions with different dtypes.""" + cfg = RTCConfig(enabled=True, execution_horizon=10) + queue = ActionQueue(cfg) + + # Use float64 instead of default float32 + actions_f64 = torch.randn(20, 6, dtype=torch.float64) + queue.merge(actions_f64, actions_f64, real_delay=0) + + action = queue.get() + assert action.dtype == torch.float64 + + +def test_empty_with_none_queue(action_queue_rtc_enabled): + """Test empty() correctly handles None queue.""" + assert 
action_queue_rtc_enabled.queue is None + assert action_queue_rtc_enabled.empty() is True + + +def test_qsize_with_none_queue(action_queue_rtc_enabled): + """Test qsize() correctly handles None queue.""" + assert action_queue_rtc_enabled.queue is None + assert action_queue_rtc_enabled.qsize() == 0 + + +# ====================== Integration Tests ====================== + + +def test_typical_rtc_workflow(action_queue_rtc_enabled, sample_actions): + """Test a typical RTC workflow: merge, consume, merge with delay.""" + # First inference + action_queue_rtc_enabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=0) + initial_size = action_queue_rtc_enabled.qsize() + assert initial_size == 50 + + # Consume 10 actions (execution_horizon) + for _ in range(10): + action = action_queue_rtc_enabled.get() + assert action is not None + + assert action_queue_rtc_enabled.qsize() == 40 + + # Second inference with delay + action_index_before = action_queue_rtc_enabled.get_action_index() + + action_queue_rtc_enabled.merge( + sample_actions["original"], + sample_actions["processed"], + real_delay=5, + action_index_before_inference=action_index_before, + ) + + # Queue should be replaced, minus delay + assert action_queue_rtc_enabled.qsize() == 45 + assert action_queue_rtc_enabled.get_action_index() == 0 + + +def test_typical_non_rtc_workflow(action_queue_rtc_disabled, sample_actions): + """Test a typical non-RTC workflow: merge, consume, merge again.""" + # First inference + action_queue_rtc_disabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=0) + assert action_queue_rtc_disabled.qsize() == 50 + + # Consume 40 actions + for _ in range(40): + action = action_queue_rtc_disabled.get() + assert action is not None + + assert action_queue_rtc_disabled.qsize() == 10 + + # Second inference (should append) + action_queue_rtc_disabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=0) + + # Should have 10 remaining 
+ 50 new = 60 + assert action_queue_rtc_disabled.qsize() == 60 diff --git a/tests/policies/rtc/test_configuration_rtc.py b/tests/policies/rtc/test_configuration_rtc.py new file mode 100644 index 000000000..2251e007c --- /dev/null +++ b/tests/policies/rtc/test_configuration_rtc.py @@ -0,0 +1,323 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for RTC configuration module.""" + +import pytest + +from lerobot.configs.types import RTCAttentionSchedule +from lerobot.policies.rtc.configuration_rtc import RTCConfig + +# ====================== Initialization Tests ====================== + + +def test_rtc_config_default_initialization(): + """Test RTCConfig initializes with default values.""" + config = RTCConfig() + + assert config.enabled is False + assert config.prefix_attention_schedule == RTCAttentionSchedule.LINEAR + assert config.max_guidance_weight == 10.0 + assert config.execution_horizon == 10 + assert config.debug is False + assert config.debug_maxlen == 100 + + +def test_rtc_config_custom_initialization(): + """Test RTCConfig initializes with custom values.""" + config = RTCConfig( + enabled=True, + prefix_attention_schedule=RTCAttentionSchedule.EXP, + max_guidance_weight=5.0, + execution_horizon=20, + debug=True, + debug_maxlen=200, + ) + + assert config.enabled is True + assert config.prefix_attention_schedule == RTCAttentionSchedule.EXP + assert 
config.max_guidance_weight == 5.0 + assert config.execution_horizon == 20 + assert config.debug is True + assert config.debug_maxlen == 200 + + +def test_rtc_config_partial_initialization(): + """Test RTCConfig with partial custom values.""" + config = RTCConfig(enabled=True, max_guidance_weight=15.0) + + assert config.enabled is True + assert config.max_guidance_weight == 15.0 + # Other values should be defaults + assert config.prefix_attention_schedule == RTCAttentionSchedule.LINEAR + assert config.execution_horizon == 10 + assert config.debug is False + + +# ====================== Validation Tests ====================== + + +def test_rtc_config_validates_positive_max_guidance_weight(): + """Test RTCConfig validates max_guidance_weight is positive.""" + with pytest.raises(ValueError, match="max_guidance_weight must be positive"): + RTCConfig(max_guidance_weight=0.0) + + with pytest.raises(ValueError, match="max_guidance_weight must be positive"): + RTCConfig(max_guidance_weight=-1.0) + + +def test_rtc_config_validates_positive_debug_maxlen(): + """Test RTCConfig validates debug_maxlen is positive.""" + with pytest.raises(ValueError, match="debug_maxlen must be positive"): + RTCConfig(debug_maxlen=0) + + with pytest.raises(ValueError, match="debug_maxlen must be positive"): + RTCConfig(debug_maxlen=-10) + + +def test_rtc_config_accepts_valid_max_guidance_weight(): + """Test RTCConfig accepts valid positive max_guidance_weight.""" + config1 = RTCConfig(max_guidance_weight=0.1) + assert config1.max_guidance_weight == 0.1 + + config2 = RTCConfig(max_guidance_weight=100.0) + assert config2.max_guidance_weight == 100.0 + + +def test_rtc_config_accepts_valid_debug_maxlen(): + """Test RTCConfig accepts valid positive debug_maxlen.""" + config1 = RTCConfig(debug_maxlen=1) + assert config1.debug_maxlen == 1 + + config2 = RTCConfig(debug_maxlen=10000) + assert config2.debug_maxlen == 10000 + + +# ====================== Attention Schedule Tests ====================== + + 
+def test_rtc_config_with_linear_schedule(): + """Test RTCConfig with LINEAR attention schedule.""" + config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.LINEAR) + assert config.prefix_attention_schedule == RTCAttentionSchedule.LINEAR + + +def test_rtc_config_with_exp_schedule(): + """Test RTCConfig with EXP attention schedule.""" + config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.EXP) + assert config.prefix_attention_schedule == RTCAttentionSchedule.EXP + + +def test_rtc_config_with_zeros_schedule(): + """Test RTCConfig with ZEROS attention schedule.""" + config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.ZEROS) + assert config.prefix_attention_schedule == RTCAttentionSchedule.ZEROS + + +def test_rtc_config_with_ones_schedule(): + """Test RTCConfig with ONES attention schedule.""" + config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.ONES) + assert config.prefix_attention_schedule == RTCAttentionSchedule.ONES + + +# ====================== Enabled/Disabled Tests ====================== + + +def test_rtc_config_enabled_true(): + """Test RTCConfig with enabled=True.""" + config = RTCConfig(enabled=True) + assert config.enabled is True + + +def test_rtc_config_enabled_false(): + """Test RTCConfig with enabled=False.""" + config = RTCConfig(enabled=False) + assert config.enabled is False + + +# ====================== Debug Tests ====================== + + +def test_rtc_config_debug_enabled(): + """Test RTCConfig with debug enabled.""" + config = RTCConfig(debug=True, debug_maxlen=500) + assert config.debug is True + assert config.debug_maxlen == 500 + + +def test_rtc_config_debug_disabled(): + """Test RTCConfig with debug disabled.""" + config = RTCConfig(debug=False) + assert config.debug is False + + +# ====================== Execution Horizon Tests ====================== + + +def test_rtc_config_with_small_execution_horizon(): + """Test RTCConfig with small execution horizon.""" + config = 
RTCConfig(execution_horizon=1) + assert config.execution_horizon == 1 + + +def test_rtc_config_with_large_execution_horizon(): + """Test RTCConfig with large execution horizon.""" + config = RTCConfig(execution_horizon=100) + assert config.execution_horizon == 100 + + +def test_rtc_config_with_zero_execution_horizon(): + """Test RTCConfig accepts zero execution horizon.""" + # No validation on execution_horizon, so this should work + config = RTCConfig(execution_horizon=0) + assert config.execution_horizon == 0 + + +def test_rtc_config_with_negative_execution_horizon(): + """Test RTCConfig accepts negative execution horizon.""" + # No validation on execution_horizon, so this should work + config = RTCConfig(execution_horizon=-1) + assert config.execution_horizon == -1 + + +# ====================== Integration Tests ====================== + + +def test_rtc_config_typical_production_settings(): + """Test RTCConfig with typical production settings.""" + config = RTCConfig( + enabled=True, + prefix_attention_schedule=RTCAttentionSchedule.EXP, + max_guidance_weight=10.0, + execution_horizon=8, + debug=False, + ) + + assert config.enabled is True + assert config.prefix_attention_schedule == RTCAttentionSchedule.EXP + assert config.max_guidance_weight == 10.0 + assert config.execution_horizon == 8 + assert config.debug is False + + +def test_rtc_config_typical_debug_settings(): + """Test RTCConfig with typical debug settings.""" + config = RTCConfig( + enabled=True, + prefix_attention_schedule=RTCAttentionSchedule.LINEAR, + max_guidance_weight=5.0, + execution_horizon=10, + debug=True, + debug_maxlen=1000, + ) + + assert config.enabled is True + assert config.debug is True + assert config.debug_maxlen == 1000 + + +def test_rtc_config_disabled_mode(): + """Test RTCConfig in disabled mode.""" + config = RTCConfig(enabled=False) + + assert config.enabled is False + # Other settings still accessible even when disabled + assert config.max_guidance_weight == 10.0 + assert 
config.execution_horizon == 10 + + +# ====================== Dataclass Tests ====================== + + +def test_rtc_config_is_dataclass(): + """Test that RTCConfig is a dataclass.""" + from dataclasses import is_dataclass + + assert is_dataclass(RTCConfig) + + +def test_rtc_config_equality(): + """Test RTCConfig equality comparison.""" + config1 = RTCConfig(enabled=True, max_guidance_weight=5.0) + config2 = RTCConfig(enabled=True, max_guidance_weight=5.0) + config3 = RTCConfig(enabled=False, max_guidance_weight=5.0) + + assert config1 == config2 + assert config1 != config3 + + +def test_rtc_config_repr(): + """Test RTCConfig string representation.""" + config = RTCConfig(enabled=True, execution_horizon=20) + repr_str = repr(config) + + assert "RTCConfig" in repr_str + assert "enabled=True" in repr_str + assert "execution_horizon=20" in repr_str + + +# ====================== Edge Cases Tests ====================== + + +def test_rtc_config_very_small_max_guidance_weight(): + """Test RTCConfig with very small positive max_guidance_weight.""" + config = RTCConfig(max_guidance_weight=1e-10) + assert config.max_guidance_weight == pytest.approx(1e-10) + + +def test_rtc_config_very_large_max_guidance_weight(): + """Test RTCConfig with very large max_guidance_weight.""" + config = RTCConfig(max_guidance_weight=1e10) + assert config.max_guidance_weight == pytest.approx(1e10) + + +def test_rtc_config_minimum_debug_maxlen(): + """Test RTCConfig with minimum valid debug_maxlen.""" + config = RTCConfig(debug_maxlen=1) + assert config.debug_maxlen == 1 + + +def test_rtc_config_float_max_guidance_weight(): + """Test RTCConfig with float max_guidance_weight.""" + config = RTCConfig(max_guidance_weight=3.14159) + assert config.max_guidance_weight == pytest.approx(3.14159) + + +# ====================== Type Tests ====================== + + +def test_rtc_config_enabled_type(): + """Test RTCConfig enabled field accepts boolean.""" + config = RTCConfig(enabled=True) + assert 
isinstance(config.enabled, bool) + + +def test_rtc_config_execution_horizon_type(): + """Test RTCConfig execution_horizon field accepts integer.""" + config = RTCConfig(execution_horizon=15) + assert isinstance(config.execution_horizon, int) + + +def test_rtc_config_max_guidance_weight_type(): + """Test RTCConfig max_guidance_weight field accepts float.""" + config = RTCConfig(max_guidance_weight=7.5) + assert isinstance(config.max_guidance_weight, float) + + +def test_rtc_config_debug_maxlen_type(): + """Test RTCConfig debug_maxlen field accepts integer.""" + config = RTCConfig(debug_maxlen=200) + assert isinstance(config.debug_maxlen, int) diff --git a/tests/policies/rtc/test_debug_visualizer.py b/tests/policies/rtc/test_debug_visualizer.py new file mode 100644 index 000000000..41b2926fe --- /dev/null +++ b/tests/policies/rtc/test_debug_visualizer.py @@ -0,0 +1,427 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for RTC debug visualizer module.""" + +from unittest.mock import MagicMock + +import numpy as np +import pytest +import torch + +from lerobot.policies.rtc.debug_visualizer import RTCDebugVisualizer + +# ====================== Fixtures ====================== + + +@pytest.fixture +def mock_axes(): + """Create mock matplotlib axes.""" + axes = [] + for _ in range(6): + ax = MagicMock() + ax.xaxis.get_label.return_value.get_text.return_value = "" + ax.yaxis.get_label.return_value.get_text.return_value = "" + axes.append(ax) + return axes + + +@pytest.fixture +def sample_tensor_2d(): + """Create a 2D sample tensor (time_steps, num_dims).""" + return torch.randn(50, 6) + + +@pytest.fixture +def sample_tensor_3d(): + """Create a 3D sample tensor (batch, time_steps, num_dims).""" + return torch.randn(1, 50, 6) + + +@pytest.fixture +def sample_numpy_2d(): + """Create a 2D numpy array.""" + return np.random.randn(50, 6) + + +# ====================== Basic Plotting Tests ====================== + + +def test_plot_waypoints_with_2d_tensor(mock_axes, sample_tensor_2d): + """Test plot_waypoints with 2D tensor.""" + RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d) + + # Should call plot on each axis (6 dimensions) + for ax in mock_axes: + ax.plot.assert_called_once() + + +def test_plot_waypoints_with_3d_tensor(mock_axes, sample_tensor_3d): + """Test plot_waypoints with 3D tensor (batch dimension).""" + RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_3d) + + # Should still plot 6 dimensions (batch dimension removed) + for ax in mock_axes: + ax.plot.assert_called_once() + + +def test_plot_waypoints_with_numpy_array(mock_axes, sample_numpy_2d): + """Test plot_waypoints with numpy array.""" + RTCDebugVisualizer.plot_waypoints(mock_axes, sample_numpy_2d) + + # Should work with numpy arrays + for ax in mock_axes: + ax.plot.assert_called_once() + + +def test_plot_waypoints_with_none_tensor(mock_axes): + """Test plot_waypoints returns early when tensor 
is None.""" + RTCDebugVisualizer.plot_waypoints(mock_axes, None) + + # Should not call plot on any axis + for ax in mock_axes: + ax.plot.assert_not_called() + + +# ====================== Parameter Tests ====================== + + +def test_plot_waypoints_with_custom_color(mock_axes, sample_tensor_2d): + """Test plot_waypoints uses custom color.""" + RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, color="red") + + # Check that color was passed to plot + for ax in mock_axes: + call_kwargs = ax.plot.call_args[1] + assert call_kwargs["color"] == "red" + + +def test_plot_waypoints_with_custom_label(mock_axes, sample_tensor_2d): + """Test plot_waypoints uses custom label.""" + RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, label="test_label") + + # First axis should have label, others should not + first_ax_kwargs = mock_axes[0].plot.call_args[1] + assert first_ax_kwargs["label"] == "test_label" + + # Other axes should have empty label + for ax in mock_axes[1:]: + call_kwargs = ax.plot.call_args[1] + assert call_kwargs["label"] == "" + + +def test_plot_waypoints_with_custom_alpha(mock_axes, sample_tensor_2d): + """Test plot_waypoints uses custom alpha.""" + RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, alpha=0.5) + + for ax in mock_axes: + call_kwargs = ax.plot.call_args[1] + assert call_kwargs["alpha"] == 0.5 + + +def test_plot_waypoints_with_custom_linewidth(mock_axes, sample_tensor_2d): + """Test plot_waypoints uses custom linewidth.""" + RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, linewidth=3) + + for ax in mock_axes: + call_kwargs = ax.plot.call_args[1] + assert call_kwargs["linewidth"] == 3 + + +def test_plot_waypoints_with_marker(mock_axes, sample_tensor_2d): + """Test plot_waypoints with marker style.""" + RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, marker="o", markersize=5) + + for ax in mock_axes: + call_kwargs = ax.plot.call_args[1] + assert call_kwargs["marker"] == "o" + 
assert call_kwargs["markersize"] == 5 + + +def test_plot_waypoints_without_marker(mock_axes, sample_tensor_2d): + """Test plot_waypoints without marker (default).""" + RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, marker=None) + + # Marker should not be in kwargs when None + for ax in mock_axes: + call_kwargs = ax.plot.call_args[1] + assert "marker" not in call_kwargs + assert "markersize" not in call_kwargs + + +# ====================== start_from Parameter Tests ====================== + + +def test_plot_waypoints_with_start_from_zero(mock_axes, sample_tensor_2d): + """Test plot_waypoints with start_from=0.""" + RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, start_from=0) + + # X indices should start from 0 + for ax in mock_axes: + call_args = ax.plot.call_args[0] + x_indices = call_args[0] + assert x_indices[0] == 0 + assert len(x_indices) == 50 + + +def test_plot_waypoints_with_start_from_nonzero(mock_axes, sample_tensor_2d): + """Test plot_waypoints with start_from > 0.""" + RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, start_from=10) + + # X indices should start from 10 + for ax in mock_axes: + call_args = ax.plot.call_args[0] + x_indices = call_args[0] + assert x_indices[0] == 10 + assert x_indices[-1] == 59 # 10 + 50 - 1 + + +# ====================== Tensor Shape Tests ====================== + + +def test_plot_waypoints_with_1d_tensor(mock_axes): + """Test plot_waypoints with 1D tensor.""" + tensor_1d = torch.randn(50) + RTCDebugVisualizer.plot_waypoints(mock_axes, tensor_1d) + + # Should reshape to (50, 1) and plot on first axis only + mock_axes[0].plot.assert_called_once() + for ax in mock_axes[1:]: + ax.plot.assert_not_called() + + +def test_plot_waypoints_with_fewer_dims_than_axes(sample_tensor_2d): + """Test plot_waypoints when tensor has fewer dims than axes.""" + # Create tensor with only 3 dimensions + tensor_3d = sample_tensor_2d[:, :3] + + # Create 6 axes but tensor only has 3 dims + mock_axes = 
[MagicMock() for _ in range(6)] + for ax in mock_axes: + ax.xaxis.get_label.return_value.get_text.return_value = "" + ax.yaxis.get_label.return_value.get_text.return_value = "" + + RTCDebugVisualizer.plot_waypoints(mock_axes, tensor_3d) + + # Should only plot on first 3 axes + for i in range(3): + mock_axes[i].plot.assert_called_once() + for i in range(3, 6): + mock_axes[i].plot.assert_not_called() + + +# ====================== Axis Labeling Tests ====================== + + +def test_plot_waypoints_sets_xlabel(mock_axes, sample_tensor_2d): + """Test plot_waypoints sets x-axis label.""" + RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d) + + for ax in mock_axes: + ax.set_xlabel.assert_called_once_with("Step", fontsize=10) + + +def test_plot_waypoints_sets_ylabel(mock_axes, sample_tensor_2d): + """Test plot_waypoints sets y-axis label.""" + RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d) + + for i, ax in enumerate(mock_axes): + ax.set_ylabel.assert_called_once_with(f"Dim {i}", fontsize=10) + + +def test_plot_waypoints_skips_label_if_exists(sample_tensor_2d): + """Test plot_waypoints doesn't set labels if they already exist.""" + mock_axes_with_labels = [] + for _ in range(6): + ax = MagicMock() + # Simulate existing labels + ax.xaxis.get_label.return_value.get_text.return_value = "Existing X Label" + ax.yaxis.get_label.return_value.get_text.return_value = "Existing Y Label" + mock_axes_with_labels.append(ax) + + RTCDebugVisualizer.plot_waypoints(mock_axes_with_labels, sample_tensor_2d) + + # Should not set labels when they already exist + for ax in mock_axes_with_labels: + ax.set_xlabel.assert_not_called() + ax.set_ylabel.assert_not_called() + + +# ====================== Grid Tests ====================== + + +def test_plot_waypoints_enables_grid(mock_axes, sample_tensor_2d): + """Test plot_waypoints enables grid on all axes.""" + RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d) + + for ax in mock_axes: + 
ax.grid.assert_called_once_with(True, alpha=0.3) + + +# ====================== Legend Tests ====================== + + +def test_plot_waypoints_adds_legend_with_label(mock_axes, sample_tensor_2d): + """Test plot_waypoints adds legend when label is provided.""" + RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, label="test_label") + + # Should add legend to first axis only + mock_axes[0].legend.assert_called_once_with(loc="best", fontsize=8) + + # Should not add legend to other axes + for ax in mock_axes[1:]: + ax.legend.assert_not_called() + + +def test_plot_waypoints_no_legend_without_label(mock_axes, sample_tensor_2d): + """Test plot_waypoints doesn't add legend when no label provided.""" + RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, label="") + + # Should not add legend to any axis + for ax in mock_axes: + ax.legend.assert_not_called() + + +# ====================== Data Correctness Tests ====================== + + +def test_plot_waypoints_plots_correct_data(mock_axes, sample_tensor_2d): + """Test plot_waypoints plots correct tensor values.""" + RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, start_from=0) + + # Check first axis to verify data correctness + call_args = mock_axes[0].plot.call_args[0] + x_indices = call_args[0] + y_values = call_args[1] + + # X indices should be 0 to 49 + np.testing.assert_array_equal(x_indices, np.arange(50)) + + # Y values should match first dimension of tensor + expected_y = sample_tensor_2d[:, 0].cpu().numpy() + np.testing.assert_array_almost_equal(y_values, expected_y) + + +def test_plot_waypoints_handles_gpu_tensor(mock_axes): + """Test plot_waypoints handles GPU tensors (if CUDA available).""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + tensor_gpu = torch.randn(50, 6, device="cuda") + RTCDebugVisualizer.plot_waypoints(mock_axes, tensor_gpu) + + # Should successfully plot without errors + for ax in mock_axes: + ax.plot.assert_called_once() + + +# 
====================== Edge Cases Tests ====================== + + +def test_plot_waypoints_with_empty_tensor(mock_axes): + """Test plot_waypoints with empty tensor.""" + empty_tensor = torch.empty(0, 6) + RTCDebugVisualizer.plot_waypoints(mock_axes, empty_tensor) + + # Should plot empty data + for ax in mock_axes: + call_args = ax.plot.call_args[0] + x_indices = call_args[0] + assert len(x_indices) == 0 + + +def test_plot_waypoints_with_single_timestep(mock_axes): + """Test plot_waypoints with single timestep.""" + single_step_tensor = torch.randn(1, 6) + RTCDebugVisualizer.plot_waypoints(mock_axes, single_step_tensor) + + # Should plot single point + for ax in mock_axes: + call_args = ax.plot.call_args[0] + x_indices = call_args[0] + assert len(x_indices) == 1 + + +def test_plot_waypoints_with_very_large_tensor(mock_axes): + """Test plot_waypoints with very large tensor.""" + large_tensor = torch.randn(10000, 6) + RTCDebugVisualizer.plot_waypoints(mock_axes, large_tensor) + + # Should handle large tensors + for ax in mock_axes: + call_args = ax.plot.call_args[0] + x_indices = call_args[0] + assert len(x_indices) == 10000 + + +# ====================== Multiple Calls Tests ====================== + + +def test_plot_waypoints_multiple_calls_on_same_axes(mock_axes, sample_tensor_2d): + """Test multiple plot_waypoints calls on same axes.""" + tensor1 = sample_tensor_2d + tensor2 = torch.randn(50, 6) + + RTCDebugVisualizer.plot_waypoints(mock_axes, tensor1, color="blue", label="Series 1") + RTCDebugVisualizer.plot_waypoints(mock_axes, tensor2, color="red", label="Series 2") + + # Each axis should have been called twice + for ax in mock_axes: + assert ax.plot.call_count == 2 + + +# ====================== Integration Tests ====================== + + +def test_plot_waypoints_typical_usage(mock_axes, sample_tensor_2d): + """Test plot_waypoints with typical usage pattern.""" + RTCDebugVisualizer.plot_waypoints( + mock_axes, sample_tensor_2d, start_from=0, color="blue", 
label="Trajectory", alpha=0.7, linewidth=2 + ) + + # Verify all expected calls were made + for ax in mock_axes: + ax.plot.assert_called_once() + ax.set_xlabel.assert_called_once() + ax.set_ylabel.assert_called_once() + ax.grid.assert_called_once() + + # First axis should have legend + mock_axes[0].legend.assert_called_once() + + +def test_plot_waypoints_with_all_parameters(mock_axes, sample_tensor_2d): + """Test plot_waypoints with all parameters specified.""" + RTCDebugVisualizer.plot_waypoints( + axes=mock_axes, + tensor=sample_tensor_2d, + start_from=10, + color="green", + label="Full Test", + alpha=0.8, + linewidth=3, + marker="o", + markersize=6, + ) + + # Check first axis for all parameters + call_kwargs = mock_axes[0].plot.call_args[1] + assert call_kwargs["color"] == "green" + assert call_kwargs["label"] == "Full Test" + assert call_kwargs["alpha"] == 0.8 + assert call_kwargs["linewidth"] == 3 + assert call_kwargs["marker"] == "o" + assert call_kwargs["markersize"] == 6 diff --git a/tests/policies/rtc/test_latency_tracker.py b/tests/policies/rtc/test_latency_tracker.py new file mode 100644 index 000000000..af6b89431 --- /dev/null +++ b/tests/policies/rtc/test_latency_tracker.py @@ -0,0 +1,481 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for RTC LatencyTracker module.""" + +import pytest + +from lerobot.policies.rtc.latency_tracker import LatencyTracker + +# ====================== Fixtures ====================== + + +@pytest.fixture +def tracker(): + """Create a LatencyTracker with default maxlen.""" + return LatencyTracker(maxlen=100) + + +@pytest.fixture +def small_tracker(): + """Create a LatencyTracker with small maxlen for overflow testing.""" + return LatencyTracker(maxlen=5) + + +# ====================== Initialization Tests ====================== + + +def test_latency_tracker_initialization(): + """Test LatencyTracker initializes correctly.""" + tracker = LatencyTracker(maxlen=50) + assert len(tracker) == 0 + assert tracker.max_latency == 0.0 + assert tracker.max() == 0.0 + + +def test_latency_tracker_default_maxlen(): + """Test LatencyTracker uses default maxlen.""" + tracker = LatencyTracker() + # Should accept default maxlen=100 + assert len(tracker) == 0 + + +# ====================== add() Tests ====================== + + +def test_add_single_latency(tracker): + """Test adding a single latency value.""" + tracker.add(0.5) + assert len(tracker) == 1 + assert tracker.max() == 0.5 + + +def test_add_multiple_latencies(tracker): + """Test adding multiple latency values.""" + latencies = [0.1, 0.5, 0.3, 0.8, 0.2] + for lat in latencies: + tracker.add(lat) + + assert len(tracker) == 5 + assert tracker.max() == 0.8 + + +def test_add_negative_latency_ignored(tracker): + """Test that negative latencies are ignored.""" + tracker.add(0.5) + tracker.add(-0.1) + tracker.add(0.3) + + # Should only have 2 valid latencies + assert len(tracker) == 2 + assert tracker.max() == 0.5 + + +def test_add_zero_latency(tracker): + """Test adding zero latency.""" + tracker.add(0.0) + assert len(tracker) == 1 + assert tracker.max() == 0.0 + + +def test_add_converts_to_float(tracker): + """Test add() converts input to float.""" + tracker.add(5) # Integer + tracker.add("3.5") # String + + assert 
len(tracker) == 2 + assert tracker.max() == 5.0 + + +def test_add_updates_max_latency(tracker): + """Test that max_latency is updated correctly.""" + tracker.add(0.5) + assert tracker.max_latency == 0.5 + + tracker.add(0.3) + assert tracker.max_latency == 0.5 # Should not decrease + + tracker.add(0.9) + assert tracker.max_latency == 0.9 # Should increase + + +# ====================== reset() Tests ====================== + + +def test_reset_clears_values(tracker): + """Test reset() clears all values.""" + tracker.add(0.5) + tracker.add(0.8) + tracker.add(0.3) + assert len(tracker) == 3 + + tracker.reset() + assert len(tracker) == 0 + assert tracker.max_latency == 0.0 + + +def test_reset_clears_max_latency(tracker): + """Test reset() resets max_latency.""" + tracker.add(1.5) + assert tracker.max_latency == 1.5 + + tracker.reset() + assert tracker.max_latency == 0.0 + + +def test_reset_allows_new_values(tracker): + """Test that tracker works correctly after reset.""" + tracker.add(0.5) + tracker.reset() + + tracker.add(0.3) + assert len(tracker) == 1 + assert tracker.max() == 0.3 + + +# ====================== max() Tests ====================== + + +def test_max_returns_zero_when_empty(tracker): + """Test max() returns 0.0 when tracker is empty.""" + assert tracker.max() == 0.0 + + +def test_max_returns_maximum_value(tracker): + """Test max() returns the maximum latency.""" + latencies = [0.2, 0.8, 0.3, 0.5, 0.1] + for lat in latencies: + tracker.add(lat) + + assert tracker.max() == 0.8 + + +def test_max_persists_after_sliding_window(small_tracker): + """Test max() persists even after values slide out of window.""" + # Add values that will exceed maxlen=5 + small_tracker.add(0.1) + small_tracker.add(0.9) # This is max + small_tracker.add(0.2) + small_tracker.add(0.3) + small_tracker.add(0.4) + small_tracker.add(0.5) # This pushes out 0.1 + + # Max should still be 0.9 even though only last 5 values kept + assert small_tracker.max() == 0.9 + + +def 
test_max_after_reset(tracker): + """Test max() returns 0.0 after reset.""" + tracker.add(1.5) + tracker.reset() + assert tracker.max() == 0.0 + + +# ====================== percentile() Tests ====================== + + +def test_percentile_returns_zero_when_empty(tracker): + """Test percentile() returns 0.0 when tracker is empty.""" + assert tracker.percentile(0.5) == 0.0 + assert tracker.percentile(0.95) == 0.0 + + +def test_percentile_median(tracker): + """Test percentile(0.5) returns median.""" + # Add sorted values for easier verification + values = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] + for v in values: + tracker.add(v) + + # Median should be around 0.5 + median = tracker.percentile(0.5) + assert 0.45 <= median <= 0.55 + + +def test_percentile_minimum_with_zero(tracker): + """Test percentile(0.0) returns minimum.""" + tracker.add(0.5) + tracker.add(0.2) + tracker.add(0.8) + + assert tracker.percentile(0.0) == 0.2 + + +def test_percentile_maximum_with_one(tracker): + """Test percentile(1.0) returns maximum.""" + tracker.add(0.5) + tracker.add(0.2) + tracker.add(0.8) + + assert tracker.percentile(1.0) == 0.8 + + +def test_percentile_95(tracker): + """Test percentile(0.95) returns 95th percentile.""" + # Add 100 values from 0.0 to 0.99 + for i in range(100): + tracker.add(i / 100.0) + + p95 = tracker.percentile(0.95) + # 95th percentile should be around 0.95 + assert 0.93 <= p95 <= 0.96 + + +def test_percentile_negative_value_returns_min(tracker): + """Test percentile with negative q returns minimum.""" + tracker.add(0.5) + tracker.add(0.2) + tracker.add(0.8) + + assert tracker.percentile(-0.5) == 0.2 + + +def test_percentile_value_greater_than_one_returns_max(tracker): + """Test percentile with q > 1.0 returns maximum.""" + tracker.add(0.5) + tracker.add(0.2) + tracker.add(0.8) + + assert tracker.percentile(1.5) == 0.8 + + +# ====================== p95() Tests ====================== + + +def test_p95_returns_zero_when_empty(tracker): + """Test p95() 
returns 0.0 when tracker is empty.""" + assert tracker.p95() == 0.0 + + +def test_p95_returns_95th_percentile(tracker): + """Test p95() returns the 95th percentile.""" + # Add 100 values + for i in range(100): + tracker.add(i / 100.0) + + p95 = tracker.p95() + assert 0.93 <= p95 <= 0.96 + + +def test_p95_equals_percentile_95(tracker): + """Test p95() equals percentile(0.95).""" + for i in range(50): + tracker.add(i / 50.0) + + assert tracker.p95() == tracker.percentile(0.95) + + +# ====================== __len__() Tests ====================== + + +def test_len_returns_zero_initially(tracker): + """Test __len__ returns 0 for new tracker.""" + assert len(tracker) == 0 + + +def test_len_increments_with_add(tracker): + """Test __len__ increments as values are added.""" + assert len(tracker) == 0 + + tracker.add(0.1) + assert len(tracker) == 1 + + tracker.add(0.2) + assert len(tracker) == 2 + + tracker.add(0.3) + assert len(tracker) == 3 + + +def test_len_respects_maxlen(small_tracker): + """Test __len__ respects maxlen limit.""" + # Add more than maxlen values + for i in range(10): + small_tracker.add(i / 10.0) + + # Should only keep last 5 + assert len(small_tracker) == 5 + + +def test_len_after_reset(tracker): + """Test __len__ returns 0 after reset.""" + tracker.add(0.5) + tracker.add(0.3) + assert len(tracker) == 2 + + tracker.reset() + assert len(tracker) == 0 + + +# ====================== Sliding Window Tests ====================== + + +def test_sliding_window_removes_oldest(small_tracker): + """Test sliding window removes oldest values.""" + # Add 7 values to tracker with maxlen=5 + values = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] + for v in values: + small_tracker.add(v) + + # Should only have last 5: [0.3, 0.4, 0.5, 0.6, 0.7] + assert len(small_tracker) == 5 + + # Median should reflect last 5 values + median = small_tracker.percentile(0.5) + assert 0.45 <= median <= 0.55 + + +def test_sliding_window_maintains_max(small_tracker): + """Test sliding window maintains 
correct max even after overflow.""" + small_tracker.add(0.1) + small_tracker.add(0.9) + small_tracker.add(0.2) + small_tracker.add(0.3) + small_tracker.add(0.4) + small_tracker.add(0.5) # Pushes out 0.1 + + # Max should still be 0.9 + assert small_tracker.max() == 0.9 + + +# ====================== Edge Cases Tests ====================== + + +def test_single_value(tracker): + """Test tracker behavior with single value.""" + tracker.add(0.75) + + assert len(tracker) == 1 + assert tracker.max() == 0.75 + assert tracker.percentile(0.0) == 0.75 + assert tracker.percentile(0.5) == 0.75 + assert tracker.percentile(1.0) == 0.75 + + +def test_all_same_values(tracker): + """Test tracker with all identical values.""" + for _ in range(10): + tracker.add(0.5) + + assert len(tracker) == 10 + assert tracker.max() == 0.5 + assert tracker.percentile(0.0) == 0.5 + assert tracker.percentile(0.5) == 0.5 + assert tracker.percentile(1.0) == 0.5 + + +def test_very_small_values(tracker): + """Test tracker with very small float values.""" + tracker.add(1e-10) + tracker.add(2e-10) + tracker.add(3e-10) + + assert len(tracker) == 3 + assert tracker.max() == pytest.approx(3e-10) + + +def test_very_large_values(tracker): + """Test tracker with very large float values.""" + tracker.add(1e10) + tracker.add(2e10) + tracker.add(3e10) + + assert len(tracker) == 3 + assert tracker.max() == pytest.approx(3e10) + + +# ====================== Integration Tests ====================== + + +def test_typical_usage_pattern(tracker): + """Test a typical usage pattern of the tracker.""" + # Simulate adding latencies over time + latencies = [0.05, 0.08, 0.12, 0.07, 0.15, 0.09, 0.11, 0.06, 0.14, 0.10] + + for lat in latencies: + tracker.add(lat) + + # Check statistics + assert len(tracker) == 10 + assert tracker.max() == 0.15 + + # p95 should be close to max since we have only 10 values + p95 = tracker.p95() + assert p95 >= tracker.percentile(0.5) # p95 should be >= median + assert p95 <= tracker.max() # p95 
should be <= max + + +def test_reset_and_reuse(tracker): + """Test resetting and reusing tracker.""" + # First batch + tracker.add(1.0) + tracker.add(2.0) + assert tracker.max() == 2.0 + + # Reset + tracker.reset() + + # Second batch + tracker.add(0.5) + tracker.add(0.8) + assert len(tracker) == 2 + assert tracker.max() == 0.8 + assert tracker.percentile(0.5) <= 0.8 + + +def test_continuous_monitoring(small_tracker): + """Test continuous monitoring with sliding window.""" + # Simulate continuous latency monitoring + # First 5 latencies + for i in range(5): + small_tracker.add(0.1 * (i + 1)) + + max_before = small_tracker.max() + + # Add 5 more (window slides) + for i in range(5, 10): + small_tracker.add(0.1 * (i + 1)) + + # Max should have increased + assert small_tracker.max() > max_before + assert len(small_tracker) == 5 # Window size maintained + + +# ====================== Type Conversion Tests ====================== + + +def test_add_with_integer(tracker): + """Test adding integer values.""" + tracker.add(5) + assert len(tracker) == 1 + assert tracker.max() == 5.0 + + +def test_add_with_string_number(tracker): + """Test adding string representation of number.""" + tracker.add("3.14") + assert len(tracker) == 1 + assert tracker.max() == pytest.approx(3.14) + + +def test_percentile_converts_q_to_float(tracker): + """Test percentile converts q parameter to float.""" + tracker.add(0.5) + tracker.add(0.8) + + # Pass integer q + result = tracker.percentile(1) + assert result == 0.8 diff --git a/tests/policies/rtc/test_modeling_rtc.py b/tests/policies/rtc/test_modeling_rtc.py new file mode 100644 index 000000000..52940c6d3 --- /dev/null +++ b/tests/policies/rtc/test_modeling_rtc.py @@ -0,0 +1,784 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit and integration tests for ``RTCProcessor`` (the RTC modeling module)."""

import pytest
import torch

from lerobot.configs.types import RTCAttentionSchedule
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.modeling_rtc import RTCProcessor

# Shape used for most sample chunks in this module: (batch, time, action_dim).
_CHUNK_SHAPE = (1, 50, 6)


# ====================== Fixtures ======================


@pytest.fixture
def rtc_config_debug_enabled():
    """Provide an enabled RTC config (LINEAR schedule) with debug tracking on."""
    settings = {
        "enabled": True,
        "prefix_attention_schedule": RTCAttentionSchedule.LINEAR,
        "max_guidance_weight": 10.0,
        "execution_horizon": 10,
        "debug": True,
        "debug_maxlen": 100,
    }
    return RTCConfig(**settings)


@pytest.fixture
def rtc_config_debug_disabled():
    """Provide an enabled RTC config (LINEAR schedule) with debug tracking off."""
    settings = {
        "enabled": True,
        "prefix_attention_schedule": RTCAttentionSchedule.LINEAR,
        "max_guidance_weight": 10.0,
        "execution_horizon": 10,
        "debug": False,
    }
    return RTCConfig(**settings)


@pytest.fixture
def rtc_processor_debug_enabled(rtc_config_debug_enabled):
    """Provide an RTCProcessor built from the debug-enabled config."""
    return RTCProcessor(rtc_config_debug_enabled)


@pytest.fixture
def rtc_processor_debug_disabled(rtc_config_debug_disabled):
    """Provide an RTCProcessor built from the debug-disabled config."""
    return RTCProcessor(rtc_config_debug_disabled)


@pytest.fixture
def sample_x_t():
    """Provide a random x_t tensor shaped (batch, time, action_dim)."""
    return torch.randn(*_CHUNK_SHAPE)


@pytest.fixture
def sample_prev_chunk():
    """Provide a random previous-chunk tensor with the same shape as x_t."""
    return torch.randn(*_CHUNK_SHAPE)


# ====================== Shared helpers ======================


def _half_velocity(x):
    """Stand-in denoiser: return a tensor shaped like ``x`` filled with 0.5."""
    return torch.ones_like(x) * 0.5


def _run_step(processor, x_t, prev_chunk, time=None, denoiser=None, **overrides):
    """Call ``processor.denoise_step`` with the arguments shared by these tests.

    ``inference_delay`` is always 5. ``time`` defaults to ``torch.tensor(0.5)``
    and ``denoiser`` to ``_half_velocity``. Extra keyword arguments (for
    example ``execution_horizon``) are forwarded only when given.
    """
    step_time = torch.tensor(0.5) if time is None else time
    velocity_fn = _half_velocity if denoiser is None else denoiser
    return processor.denoise_step(
        x_t=x_t,
        prev_chunk_left_over=prev_chunk,
        inference_delay=5,
        time=step_time,
        original_denoise_step_partial=velocity_fn,
        **overrides,
    )


def _prefix_weights(schedule, start, end, total):
    """Build a processor for ``schedule`` and return its prefix weights."""
    rtc = RTCProcessor(RTCConfig(prefix_attention_schedule=schedule))
    return rtc.get_prefix_weights(start=start, end=end, total=total)


def _default_processor():
    """Return an RTCProcessor built from a default-constructed RTCConfig."""
    return RTCProcessor(RTCConfig())


# ====================== Initialization Tests ======================


def test_rtc_processor_initialization_with_debug(rtc_config_debug_enabled):
    """A debug-enabled config yields a processor holding an enabled tracker."""
    rtc = RTCProcessor(rtc_config_debug_enabled)

    assert rtc.rtc_config == rtc_config_debug_enabled
    assert rtc.tracker is not None
    assert rtc.tracker.enabled is True


def test_rtc_processor_initialization_without_debug(rtc_config_debug_disabled):
    """A debug-disabled config yields a processor whose tracker is None."""
    rtc = RTCProcessor(rtc_config_debug_disabled)

    assert rtc.rtc_config == rtc_config_debug_disabled
    assert rtc.tracker is None


# ====================== Tracker Proxy Methods Tests ======================


def test_track_when_tracker_enabled(rtc_processor_debug_enabled, sample_x_t):
    """track() records exactly one step, with its time, when debug is on."""
    rtc_processor_debug_enabled.track(
        time=torch.tensor(0.5),
        x_t=sample_x_t,
        v_t=sample_x_t,
        guidance_weight=2.0,
    )

    recorded = rtc_processor_debug_enabled.get_all_debug_steps()
    assert len(recorded) == 1
    assert recorded[0].time == 0.5


def test_track_when_tracker_disabled(rtc_processor_debug_disabled, sample_x_t):
    """track() is a silent no-op when debug is off."""
    # The call itself must not raise.
    rtc_processor_debug_disabled.track(
        time=torch.tensor(0.5),
        x_t=sample_x_t,
        v_t=sample_x_t,
    )

    recorded = rtc_processor_debug_disabled.get_all_debug_steps()
    assert len(recorded) == 0


def test_get_all_debug_steps_when_enabled(rtc_processor_debug_enabled, sample_x_t):
    """get_all_debug_steps() returns one entry per tracked call."""
    for step_time in (0.5, 0.4):
        rtc_processor_debug_enabled.track(time=torch.tensor(step_time), x_t=sample_x_t)

    recorded = rtc_processor_debug_enabled.get_all_debug_steps()
    assert len(recorded) == 2


def test_get_all_debug_steps_when_disabled(rtc_processor_debug_disabled):
    """get_all_debug_steps() returns an empty list when debug is off."""
    recorded = rtc_processor_debug_disabled.get_all_debug_steps()

    assert recorded == []
    assert isinstance(recorded, list)


def test_is_debug_enabled_when_tracker_exists(rtc_processor_debug_enabled):
    """is_debug_enabled() is True for the debug-enabled processor."""
    assert rtc_processor_debug_enabled.is_debug_enabled() is True


def test_is_debug_enabled_when_tracker_disabled(rtc_processor_debug_disabled):
    """is_debug_enabled() is False for the debug-disabled processor."""
    assert rtc_processor_debug_disabled.is_debug_enabled() is False


def test_reset_tracker_when_enabled(rtc_processor_debug_enabled, sample_x_t):
    """reset_tracker() empties the list of recorded steps."""
    for step_time in (0.5, 0.4):
        rtc_processor_debug_enabled.track(time=torch.tensor(step_time), x_t=sample_x_t)
    assert len(rtc_processor_debug_enabled.get_all_debug_steps()) == 2

    rtc_processor_debug_enabled.reset_tracker()
    assert len(rtc_processor_debug_enabled.get_all_debug_steps()) == 0


def test_reset_tracker_when_disabled(rtc_processor_debug_disabled):
    """reset_tracker() must not raise when debug is off."""
    rtc_processor_debug_disabled.reset_tracker()


# ====================== get_prefix_weights Tests ======================


def test_get_prefix_weights_zeros_schedule():
    """ZEROS schedule: ones before ``start``, zeros from ``start`` onward."""
    prefix = _prefix_weights(RTCAttentionSchedule.ZEROS, start=5, end=10, total=20)

    assert prefix.shape == (20,)
    assert torch.all(prefix[:5] == 1.0)
    assert torch.all(prefix[5:] == 0.0)


def test_get_prefix_weights_ones_schedule():
    """ONES schedule: ones before ``end``, zeros from ``end`` onward."""
    prefix = _prefix_weights(RTCAttentionSchedule.ONES, start=5, end=15, total=20)

    assert prefix.shape == (20,)
    assert torch.all(prefix[:15] == 1.0)
    assert torch.all(prefix[15:] == 0.0)


def test_get_prefix_weights_linear_schedule():
    """LINEAR schedule: ones, then a decreasing ramp within [0, 1], then zeros."""
    prefix = _prefix_weights(RTCAttentionSchedule.LINEAR, start=5, end=15, total=20)

    assert prefix.shape == (20,)

    # Leading section is all ones.
    assert torch.all(prefix[:5] == 1.0)

    # The ramp between start and end goes down and stays inside [0, 1].
    ramp = prefix[5:15]
    assert ramp[0] > ramp[-1]
    assert torch.all(ramp >= 0.0)
    assert torch.all(ramp <= 1.0)

    # Trailing section is all zeros.
    assert torch.all(prefix[15:] == 0.0)


def test_get_prefix_weights_exp_schedule():
    """EXP schedule: ones, then values within [0, 1], then zeros."""
    prefix = _prefix_weights(RTCAttentionSchedule.EXP, start=5, end=15, total=20)

    assert prefix.shape == (20,)

    # Leading section is all ones.
    assert torch.all(prefix[:5] == 1.0)

    # Only the value range of the middle section is checked here.
    middle = prefix[5:15]
    assert torch.all(middle >= 0.0)
    assert torch.all(middle <= 1.0)

    # Trailing section is all zeros.
    assert torch.all(prefix[15:] == 0.0)


def test_get_prefix_weights_with_start_equals_end():
    """start == end: ones up to that index, zeros afterwards."""
    prefix = _prefix_weights(RTCAttentionSchedule.LINEAR, start=10, end=10, total=20)

    assert torch.all(prefix[:10] == 1.0)
    assert torch.all(prefix[10:] == 0.0)


def test_get_prefix_weights_with_start_greater_than_end():
    """start > end: the expected result has ones up to ``end`` (10), then zeros."""
    prefix = _prefix_weights(RTCAttentionSchedule.LINEAR, start=15, end=10, total=20)

    assert torch.all(prefix[:10] == 1.0)
    assert torch.all(prefix[10:] == 0.0)


# ====================== Helper Method Tests ======================


def test_linweights_normal_case():
    """_linweights(5, 15, 20) gives 10 non-increasing values strictly inside (0, 1)."""
    ramp = _default_processor()._linweights(start=5, end=15, total=20)

    assert len(ramp) == 10
    assert ramp[0] < 1.0  # The endpoint 1.0 itself is excluded
    assert ramp[-1] > 0.0  # The endpoint 0.0 itself is excluded
    assert torch.all(ramp[:-1] >= ramp[1:])  # Never increases


def test_linweights_with_end_equals_start():
    """_linweights returns an empty result when end == start."""
    ramp = _default_processor()._linweights(start=10, end=10, total=20)

    assert len(ramp) == 0


def test_linweights_with_end_less_than_start():
    """_linweights returns an empty result when end < start."""
    ramp = _default_processor()._linweights(start=15, end=10, total=20)

    assert len(ramp) == 0


def test_add_trailing_zeros_normal():
    """_add_trailing_zeros pads 5 weights up to length 10 with zeros."""
    head = torch.tensor([1.0, 0.8, 0.6, 0.4, 0.2])

    padded = _default_processor()._add_trailing_zeros(head, total=10, end=5)

    # total - end = 5 zeros are appended after the original values.
    assert len(padded) == 10
    assert torch.all(padded[:5] == head)
    assert torch.all(padded[5:] == 0.0)


def test_add_trailing_zeros_no_zeros_needed():
    """_add_trailing_zeros returns the weights unchanged when total - end <= 0."""
    head = torch.tensor([1.0, 0.8, 0.6])

    padded = _default_processor()._add_trailing_zeros(head, total=3, end=5)

    assert torch.equal(padded, head)


def test_add_leading_ones_normal():
    """_add_leading_ones puts ``start`` ones in front of the given weights."""
    tail = torch.tensor([0.8, 0.6, 0.4, 0.2, 0.0])

    padded = _default_processor()._add_leading_ones(tail, start=3, total=10)

    # 3 ones followed by the 5 original values.
    assert len(padded) == 8
    assert torch.all(padded[:3] == 1.0)
    assert torch.all(padded[3:] == tail)


def test_add_leading_ones_no_ones_needed():
    """_add_leading_ones returns the weights unchanged when start == 0."""
    tail = torch.tensor([0.8, 0.6, 0.4])

    padded = _default_processor()._add_leading_ones(tail, start=0, total=10)

    assert torch.equal(padded, tail)


# ====================== denoise_step Tests ======================


def test_denoise_step_without_prev_chunk(rtc_processor_debug_disabled):
    """With no previous chunk, denoise_step returns the denoiser output as is."""
    noisy = torch.randn(*_CHUNK_SHAPE)

    out = _run_step(rtc_processor_debug_disabled, noisy, None)

    assert torch.allclose(out, _half_velocity(noisy))


def test_denoise_step_with_prev_chunk(rtc_processor_debug_disabled):
    """With a previous chunk, the output differs from the plain denoiser output."""
    noisy = torch.randn(*_CHUNK_SHAPE)
    leftover = torch.randn(*_CHUNK_SHAPE)

    out = _run_step(rtc_processor_debug_disabled, noisy, leftover)

    # Guidance must have changed the velocity while keeping the shape.
    unguided = _half_velocity(noisy)
    assert not torch.allclose(out, unguided)
    assert out.shape == noisy.shape


def test_denoise_step_adds_batch_dimension():
    """2D inputs (no batch axis) produce a 2D output of the same shape."""
    rtc = RTCProcessor(RTCConfig(execution_horizon=10, max_guidance_weight=5.0))
    noisy = torch.randn(50, 6)
    leftover = torch.randn(50, 6)

    out = _run_step(rtc, noisy, leftover)

    assert out.ndim == 2
    assert out.shape == (50, 6)


def test_denoise_step_pads_shorter_prev_chunk():
    """A previous chunk with fewer timesteps than x_t is accepted."""
    rtc = RTCProcessor(RTCConfig(execution_horizon=10, max_guidance_weight=5.0))
    noisy = torch.randn(*_CHUNK_SHAPE)
    leftover = torch.randn(1, 30, 6)  # 30 timesteps against 50 in x_t

    out = _run_step(rtc, noisy, leftover)

    # Only successful completion and the output shape are checked.
    assert out.shape == noisy.shape


def test_denoise_step_pads_fewer_action_dims():
    """A previous chunk with fewer action dimensions than x_t is accepted."""
    rtc = RTCProcessor(RTCConfig(execution_horizon=10, max_guidance_weight=5.0))
    noisy = torch.randn(*_CHUNK_SHAPE)
    leftover = torch.randn(1, 50, 4)  # 4 action dims against 6 in x_t

    out = _run_step(rtc, noisy, leftover)

    # Only successful completion and the output shape are checked.
    assert out.shape == noisy.shape


def test_denoise_step_uses_custom_execution_horizon():
    """denoise_step accepts an execution_horizon argument overriding the config."""
    rtc = RTCProcessor(RTCConfig(execution_horizon=10))
    noisy = torch.randn(*_CHUNK_SHAPE)
    leftover = torch.randn(*_CHUNK_SHAPE)

    # 20 here replaces the configured value of 10 for this call.
    out = _run_step(rtc, noisy, leftover, execution_horizon=20)

    assert out.shape == noisy.shape


def test_denoise_step_clamps_execution_horizon_to_prev_chunk_length():
    """A configured horizon longer than the previous chunk still works."""
    rtc = RTCProcessor(RTCConfig(execution_horizon=100))  # Far above 20
    noisy = torch.randn(*_CHUNK_SHAPE)
    leftover = torch.randn(1, 20, 6)  # Only 20 timesteps available

    out = _run_step(rtc, noisy, leftover)

    assert out.shape == noisy.shape


def test_denoise_step_guidance_weight_calculation():
    """A plain float time of 0.5 gives a finite output of the right shape."""
    rtc = RTCProcessor(RTCConfig(max_guidance_weight=10.0))
    noisy = torch.randn(*_CHUNK_SHAPE)
    leftover = torch.randn(*_CHUNK_SHAPE)

    # Time is deliberately a Python float here, not a tensor.
    out = _run_step(rtc, noisy, leftover, time=0.5)

    assert out.shape == noisy.shape
    assert not torch.any(torch.isnan(out))
    assert not torch.any(torch.isinf(out))


def test_denoise_step_guidance_weight_at_time_zero():
    """time=0 (noted as tau=1 in the original test) produces no NaN or Inf."""
    rtc = RTCProcessor(RTCConfig(max_guidance_weight=10.0))
    noisy = torch.randn(*_CHUNK_SHAPE)
    leftover = torch.randn(*_CHUNK_SHAPE)

    out = _run_step(rtc, noisy, leftover, time=torch.tensor(0.0))

    assert not torch.any(torch.isnan(out))
    assert not torch.any(torch.isinf(out))


def test_denoise_step_guidance_weight_at_time_one():
    """time=1 (noted as tau=0 in the original test) produces no Inf."""
    rtc = RTCProcessor(RTCConfig(max_guidance_weight=10.0))
    noisy = torch.randn(*_CHUNK_SHAPE)
    leftover = torch.randn(*_CHUNK_SHAPE)

    # At tau=0 the ratio (1 - tau) / tau is infinite; the weight is expected
    # to be capped at max_guidance_weight so that no Inf reaches the output.
    out = _run_step(rtc, noisy, leftover, time=torch.tensor(1.0))

    assert not torch.any(torch.isinf(out))


def test_denoise_step_tracks_debug_info(rtc_processor_debug_enabled):
    """With debug on, one guided step records one fully populated entry."""
    noisy = torch.randn(*_CHUNK_SHAPE)
    leftover = torch.randn(*_CHUNK_SHAPE)

    _run_step(rtc_processor_debug_enabled, noisy, leftover)

    recorded = rtc_processor_debug_enabled.get_all_debug_steps()
    assert len(recorded) == 1

    # Every field checked below must be filled in on the recorded entry.
    entry = recorded[0]
    assert entry.time == 0.5
    assert entry.x1_t is not None
    assert entry.correction is not None
    assert entry.err is not None
    assert entry.weights is not None
    assert entry.guidance_weight is not None
    assert entry.inference_delay == 5


def test_denoise_step_doesnt_track_without_debug(rtc_processor_debug_disabled):
    """With debug off, a guided step records nothing."""
    noisy = torch.randn(*_CHUNK_SHAPE)
    leftover = torch.randn(*_CHUNK_SHAPE)

    _run_step(rtc_processor_debug_disabled, noisy, leftover)

    recorded = rtc_processor_debug_disabled.get_all_debug_steps()
    assert len(recorded) == 0


# ====================== Integration Tests ======================


def test_denoise_step_full_workflow():
    """Run an unguided step, then a guided step fed with the first result."""
    rtc = RTCProcessor(
        RTCConfig(
            enabled=True,
            prefix_attention_schedule=RTCAttentionSchedule.LINEAR,
            max_guidance_weight=5.0,
            execution_horizon=10,
            debug=True,
        )
    )

    first_input = torch.randn(*_CHUNK_SHAPE)
    second_input = torch.randn(*_CHUNK_SHAPE)

    def noisy_velocity(x):
        return torch.randn_like(x) * 0.1

    # Step one has no previous chunk, so no guidance is involved.
    first_out = _run_step(
        rtc, first_input, None, time=torch.tensor(0.8), denoiser=noisy_velocity
    )

    # Step two uses the first output as the previous chunk.
    second_out = _run_step(
        rtc, second_input, first_out, time=torch.tensor(0.6), denoiser=noisy_velocity
    )

    assert first_out.shape == (1, 50, 6)
    assert second_out.shape == (1, 50, 6)

    # Only the guided (second) step is expected in the debug record.
    recorded = rtc.get_all_debug_steps()
    assert len(recorded) == 1


def test_get_prefix_weights_integration():
    """Every schedule yields 20 finite weights within [0, 1]."""
    all_schedules = (
        RTCAttentionSchedule.ZEROS,
        RTCAttentionSchedule.ONES,
        RTCAttentionSchedule.LINEAR,
        RTCAttentionSchedule.EXP,
    )

    for schedule in all_schedules:
        prefix = _prefix_weights(schedule, start=5, end=15, total=20)

        # Shape
        assert prefix.shape == (20,)

        # Value range
        assert torch.all(prefix >= 0.0)
        assert torch.all(prefix <= 1.0)

        # Finiteness
        assert not torch.any(torch.isnan(prefix))
        assert not torch.any(torch.isinf(prefix))


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_denoise_step_with_cuda_tensors():
    """CUDA inputs give a CUDA output of the same shape."""
    rtc = RTCProcessor(RTCConfig(execution_horizon=10, max_guidance_weight=5.0))
    noisy = torch.randn(*_CHUNK_SHAPE, device="cuda")
    leftover = torch.randn(*_CHUNK_SHAPE, device="cuda")

    out = _run_step(rtc, noisy, leftover)

    assert out.device.type == "cuda"
    assert out.shape == noisy.shape


def test_denoise_step_deterministic_with_same_inputs():
    """Two calls on identical inputs with a fixed denoiser give matching outputs."""
    rtc = RTCProcessor(RTCConfig(execution_horizon=10, max_guidance_weight=5.0))

    torch.manual_seed(42)
    noisy = torch.randn(*_CHUNK_SHAPE)
    leftover = torch.randn(*_CHUNK_SHAPE)

    # Each call gets fresh clones so neither can see the other's tensors.
    first_out = _run_step(rtc, noisy.clone(), leftover.clone())
    second_out = _run_step(rtc, noisy.clone(), leftover.clone())

    assert torch.allclose(first_out, second_out)