mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 11:09:59 +00:00
Add more tests
This commit is contained in:
@@ -0,0 +1,825 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for RTC ActionQueue module."""
|
||||
|
||||
import threading
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.policies.rtc.action_queue import ActionQueue
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
|
||||
# ====================== Fixtures ======================
|
||||
|
||||
|
||||
@pytest.fixture
def rtc_config_enabled():
    """Provide an ``RTCConfig`` with the ``enabled`` flag switched on."""
    settings = {
        "enabled": True,
        "execution_horizon": 10,
        "max_guidance_weight": 1.0,
    }
    return RTCConfig(**settings)
|
||||
|
||||
|
||||
@pytest.fixture
def rtc_config_disabled():
    """Provide an ``RTCConfig`` with the ``enabled`` flag switched off."""
    settings = {
        "enabled": False,
        "execution_horizon": 10,
        "max_guidance_weight": 1.0,
    }
    return RTCConfig(**settings)
|
||||
|
||||
|
||||
@pytest.fixture
def sample_actions():
    """Create reproducible sample action tensors for testing.

    Returns:
        A dict mapping a name to a ``(time_steps, action_dim)`` tensor with
        6 action dimensions: ``"original"`` and ``"processed"`` (50 steps),
        ``"short"`` (10 steps) and ``"longer"`` (100 steps).
    """
    # A private, seeded generator makes the data identical on every run (so a
    # failing test can be reproduced) without touching torch's global RNG.
    generator = torch.Generator().manual_seed(0)
    return {
        "original": torch.randn(50, 6, generator=generator),  # (time_steps, action_dim)
        "processed": torch.randn(50, 6, generator=generator),
        "short": torch.randn(10, 6, generator=generator),
        "longer": torch.randn(100, 6, generator=generator),
    }
|
||||
|
||||
|
||||
@pytest.fixture
def action_queue_rtc_enabled(rtc_config_enabled):
    """Provide a fresh ``ActionQueue`` built from the RTC-enabled config."""
    queue = ActionQueue(rtc_config_enabled)
    return queue
|
||||
|
||||
|
||||
@pytest.fixture
def action_queue_rtc_disabled(rtc_config_disabled):
    """Provide a fresh ``ActionQueue`` built from the RTC-disabled config."""
    queue = ActionQueue(rtc_config_disabled)
    return queue
|
||||
|
||||
|
||||
# ====================== Initialization Tests ======================
|
||||
|
||||
|
||||
def test_action_queue_initialization_rtc_enabled(rtc_config_enabled):
    """A queue built from an RTC-enabled config starts out with no data."""
    fresh = ActionQueue(rtc_config_enabled)

    # Nothing has been merged yet, so both buffers are still unset.
    assert fresh.queue is None
    assert fresh.original_queue is None
    assert fresh.last_index == 0

    # The config handed to the constructor is exposed as ``cfg``.
    assert fresh.cfg.enabled is True
|
||||
|
||||
|
||||
def test_action_queue_initialization_rtc_disabled(rtc_config_disabled):
    """A queue built from an RTC-disabled config starts out with no data."""
    fresh = ActionQueue(rtc_config_disabled)

    # Nothing has been merged yet, so both buffers are still unset.
    assert fresh.queue is None
    assert fresh.original_queue is None
    assert fresh.last_index == 0

    # The config handed to the constructor is exposed as ``cfg``.
    assert fresh.cfg.enabled is False
|
||||
|
||||
|
||||
# ====================== get() Tests ======================
|
||||
|
||||
|
||||
def test_get_returns_none_when_empty(action_queue_rtc_enabled):
    """Calling get() before anything was merged yields None."""
    assert action_queue_rtc_enabled.get() is None
|
||||
|
||||
|
||||
def test_get_returns_actions_sequentially(action_queue_rtc_enabled, sample_actions):
    """Successive get() calls walk the processed actions in order."""
    queue = action_queue_rtc_enabled
    processed = sample_actions["processed"]

    # Fill the queue with no inference delay so nothing is skipped.
    queue.merge(sample_actions["original"], processed, real_delay=0)

    first = queue.get()
    assert first is not None
    assert first.shape == (6,)
    assert torch.equal(first, processed[0])

    second = queue.get()
    assert second is not None
    assert torch.equal(second, processed[1])
|
||||
|
||||
|
||||
def test_get_returns_none_after_exhaustion(action_queue_rtc_enabled, sample_actions):
    """Once every queued action was handed out, get() yields None."""
    queue = action_queue_rtc_enabled
    short = sample_actions["short"]
    queue.merge(short, short, real_delay=0)

    # The short sequence holds 10 actions; each of them must be delivered.
    for _ in range(10):
        assert queue.get() is not None

    # The eleventh request finds nothing left.
    assert queue.get() is None
|
||||
|
||||
|
||||
def test_get_increments_last_index(action_queue_rtc_enabled, sample_actions):
    """Every get() call advances ``last_index`` by exactly one."""
    queue = action_queue_rtc_enabled
    queue.merge(sample_actions["original"], sample_actions["processed"], real_delay=0)

    assert queue.last_index == 0
    for expected_index in (1, 2):
        queue.get()
        assert queue.last_index == expected_index
|
||||
|
||||
|
||||
# ====================== qsize() Tests ======================
|
||||
|
||||
|
||||
def test_qsize_returns_zero_when_empty(action_queue_rtc_enabled):
    """A queue that never received actions reports a size of 0."""
    size = action_queue_rtc_enabled.qsize()
    assert size == 0
|
||||
|
||||
|
||||
def test_qsize_returns_correct_size(action_queue_rtc_enabled, sample_actions):
    """qsize() reports how many actions are still waiting to be consumed."""
    queue = action_queue_rtc_enabled
    short = sample_actions["short"]
    queue.merge(short, short, real_delay=0)
    assert queue.qsize() == 10

    # Each consumed action shrinks the reported size by one.
    for expected_size in (9, 8):
        queue.get()
        assert queue.qsize() == expected_size
|
||||
|
||||
|
||||
def test_qsize_after_exhaustion(action_queue_rtc_enabled, sample_actions):
    """qsize() drops to 0 once every queued action was consumed."""
    queue = action_queue_rtc_enabled
    short = sample_actions["short"]
    queue.merge(short, short, real_delay=0)

    # Drain all 10 actions of the short sequence.
    consumed = 0
    while consumed < 10:
        queue.get()
        consumed += 1

    assert queue.qsize() == 0
|
||||
|
||||
|
||||
# ====================== empty() Tests ======================
|
||||
|
||||
|
||||
def test_empty_returns_true_when_empty(action_queue_rtc_enabled):
    """A queue that never received actions reports itself as empty."""
    is_empty = action_queue_rtc_enabled.empty()
    assert is_empty is True
|
||||
|
||||
|
||||
def test_empty_returns_false_when_not_empty(action_queue_rtc_enabled, sample_actions):
    """After a merge the queue no longer reports itself as empty."""
    short = sample_actions["short"]
    action_queue_rtc_enabled.merge(short, short, real_delay=0)

    is_empty = action_queue_rtc_enabled.empty()
    assert is_empty is False
|
||||
|
||||
|
||||
def test_empty_after_partial_consumption(action_queue_rtc_enabled, sample_actions):
    """Consuming only part of the queue leaves it non-empty."""
    queue = action_queue_rtc_enabled
    short = sample_actions["short"]
    queue.merge(short, short, real_delay=0)

    # Take two of the ten actions.
    for _ in range(2):
        queue.get()

    assert queue.empty() is False
|
||||
|
||||
|
||||
def test_empty_after_full_consumption(action_queue_rtc_enabled, sample_actions):
    """Consuming every queued action makes the queue empty again."""
    queue = action_queue_rtc_enabled
    short = sample_actions["short"]
    queue.merge(short, short, real_delay=0)

    # Drain all 10 actions of the short sequence.
    consumed = 0
    while consumed < 10:
        queue.get()
        consumed += 1

    assert queue.empty() is True
|
||||
|
||||
|
||||
# ====================== get_action_index() Tests ======================
|
||||
|
||||
|
||||
def test_get_action_index_initial_value(action_queue_rtc_enabled):
    """A freshly created queue reports an action index of 0."""
    index = action_queue_rtc_enabled.get_action_index()
    assert index == 0
|
||||
|
||||
|
||||
def test_get_action_index_after_consumption(action_queue_rtc_enabled, sample_actions):
    """get_action_index() counts how many actions were consumed so far."""
    queue = action_queue_rtc_enabled
    short = sample_actions["short"]
    queue.merge(short, short, real_delay=0)
    assert queue.get_action_index() == 0

    queue.get()
    assert queue.get_action_index() == 1

    # Two further reads bring the total to three.
    for _ in range(2):
        queue.get()
    assert queue.get_action_index() == 3
|
||||
|
||||
|
||||
# ====================== get_left_over() Tests ======================
|
||||
|
||||
|
||||
def test_get_left_over_returns_none_when_empty(action_queue_rtc_enabled):
    """Without any merged actions get_left_over() yields None."""
    assert action_queue_rtc_enabled.get_left_over() is None
|
||||
|
||||
|
||||
def test_get_left_over_returns_all_when_unconsumed(action_queue_rtc_enabled, sample_actions):
    """With nothing consumed, the leftover equals the full original sequence."""
    short = sample_actions["short"]
    action_queue_rtc_enabled.merge(short, short, real_delay=0)

    remaining = action_queue_rtc_enabled.get_left_over()
    assert remaining is not None
    assert remaining.shape == (10, 6)
    assert torch.equal(remaining, short)
|
||||
|
||||
|
||||
def test_get_left_over_returns_remaining_after_consumption(action_queue_rtc_enabled, sample_actions):
    """The leftover only contains original actions that were not consumed yet."""
    queue = action_queue_rtc_enabled
    short = sample_actions["short"]
    queue.merge(short, short, real_delay=0)

    # Take the first three actions out of the queue.
    for _ in range(3):
        queue.get()

    remaining = queue.get_left_over()
    assert remaining is not None
    assert remaining.shape == (7, 6)
    assert torch.equal(remaining, short[3:])
|
||||
|
||||
|
||||
def test_get_left_over_returns_empty_after_exhaustion(action_queue_rtc_enabled, sample_actions):
    """After draining the queue the leftover is an empty tensor, not None."""
    queue = action_queue_rtc_enabled
    short = sample_actions["short"]
    queue.merge(short, short, real_delay=0)

    # Drain all 10 actions of the short sequence.
    consumed = 0
    while consumed < 10:
        queue.get()
        consumed += 1

    remaining = queue.get_left_over()
    assert remaining is not None
    assert remaining.shape == (0, 6)
|
||||
|
||||
|
||||
# ====================== merge() with RTC Enabled Tests ======================
|
||||
|
||||
|
||||
def test_merge_replaces_queue_when_rtc_enabled(action_queue_rtc_enabled, sample_actions):
    """With RTC enabled a second merge() replaces the queue instead of appending."""
    queue = action_queue_rtc_enabled
    short = sample_actions["short"]

    # Seed the queue with ten actions.
    queue.merge(short, short, real_delay=0)
    assert queue.qsize() == 10

    # Use up two of them.
    for _ in range(2):
        queue.get()
    assert queue.qsize() == 8

    # The new chunk has 50 actions; a delay of 5 leaves 45 and drops the
    # 8 actions that were still pending.
    queue.merge(sample_actions["original"], sample_actions["processed"], real_delay=5)

    assert queue.qsize() == 45
    assert queue.get_action_index() == 0
|
||||
|
||||
|
||||
def test_merge_respects_real_delay(action_queue_rtc_enabled, sample_actions):
    """With RTC enabled merge() skips the first ``real_delay`` actions."""
    queue = action_queue_rtc_enabled
    original = sample_actions["original"]
    processed = sample_actions["processed"]
    delay = 10

    queue.merge(original, processed, real_delay=delay)

    # The queue holds the sequence minus the skipped prefix.
    assert queue.qsize() == len(original) - delay

    # The first action handed out is the one at index ``delay``.
    head = queue.get()
    assert torch.equal(head, processed[delay])
|
||||
|
||||
|
||||
def test_merge_resets_last_index_when_rtc_enabled(action_queue_rtc_enabled, sample_actions):
    """With RTC enabled merge() puts ``last_index`` back to 0."""
    queue = action_queue_rtc_enabled
    short = sample_actions["short"]
    queue.merge(short, short, real_delay=0)

    for _ in range(2):
        queue.get()
    assert queue.last_index == 2

    # A new chunk arrives and the read position starts over.
    queue.merge(sample_actions["original"], sample_actions["processed"], real_delay=5)
    assert queue.last_index == 0
|
||||
|
||||
|
||||
def test_merge_with_zero_delay(action_queue_rtc_enabled, sample_actions):
    """A delay of zero keeps every action of the merged sequence."""
    original = sample_actions["original"]
    action_queue_rtc_enabled.merge(original, sample_actions["processed"], real_delay=0)

    expected_size = len(original)
    assert action_queue_rtc_enabled.qsize() == expected_size
|
||||
|
||||
|
||||
def test_merge_with_large_delay(action_queue_rtc_enabled, sample_actions):
    """A delay longer than the sequence leaves the queue empty."""
    short = sample_actions["short"]

    # 100 skipped steps against a sequence of only 10 actions.
    action_queue_rtc_enabled.merge(short, short, real_delay=100)

    assert action_queue_rtc_enabled.qsize() == 0
|
||||
|
||||
|
||||
# ====================== merge() with RTC Disabled Tests ======================
|
||||
|
||||
|
||||
def test_merge_appends_when_rtc_disabled(action_queue_rtc_disabled, sample_actions):
    """With RTC disabled a second merge() appends to the existing queue."""
    queue = action_queue_rtc_disabled
    short = sample_actions["short"]

    queue.merge(short, short, real_delay=0)
    size_after_first = queue.qsize()
    assert size_after_first == 10

    # The second chunk is added behind the first one.
    queue.merge(short, short, real_delay=0)
    assert queue.qsize() == size_after_first + 10
|
||||
|
||||
|
||||
def test_merge_removes_consumed_actions_when_appending(action_queue_rtc_disabled, sample_actions):
    """With RTC disabled merge() drops consumed actions before appending."""
    queue = action_queue_rtc_disabled
    short = sample_actions["short"]

    queue.merge(short, short, real_delay=0)
    assert queue.qsize() == 10

    # Take three actions, leaving seven pending.
    for _ in range(3):
        queue.get()
    assert queue.qsize() == 7

    # 7 pending + 10 new = 17.
    queue.merge(short, short, real_delay=0)
    assert queue.qsize() == 17
|
||||
|
||||
|
||||
def test_merge_resets_last_index_after_append(action_queue_rtc_disabled, sample_actions):
    """With RTC disabled merge() puts ``last_index`` back to 0 after appending."""
    queue = action_queue_rtc_disabled
    short = sample_actions["short"]
    queue.merge(short, short, real_delay=0)

    for _ in range(2):
        queue.get()
    assert queue.last_index == 2

    # Appending a new chunk restarts the read position.
    queue.merge(short, short, real_delay=0)
    assert queue.last_index == 0
|
||||
|
||||
|
||||
def test_merge_ignores_delay_when_rtc_disabled(action_queue_rtc_disabled, sample_actions):
    """With RTC disabled the ``real_delay`` argument does not trim the queue."""
    original = sample_actions["original"]
    action_queue_rtc_disabled.merge(original, sample_actions["processed"], real_delay=10)

    # Every action is kept although a delay of 10 was passed.
    expected_size = len(original)
    assert action_queue_rtc_disabled.qsize() == expected_size
|
||||
|
||||
|
||||
def test_merge_first_call_with_rtc_disabled(action_queue_rtc_disabled, sample_actions):
    """The very first merge() with RTC disabled stores the whole sequence."""
    queue = action_queue_rtc_disabled
    original = sample_actions["original"]
    queue.merge(original, sample_actions["processed"], real_delay=0)

    assert queue.qsize() == len(original)
    assert queue.last_index == 0
|
||||
|
||||
|
||||
# ====================== merge() with Different Action Shapes Tests ======================
|
||||
|
||||
|
||||
def test_merge_with_different_action_dims():
    """merge() works for an action dimension other than 6."""
    queue = ActionQueue(RTCConfig(enabled=True, execution_horizon=10))

    # 20 time steps with 4 action dimensions.
    four_dim_actions = torch.randn(20, 4)
    queue.merge(four_dim_actions, four_dim_actions, real_delay=5)

    head = queue.get()
    assert head.shape == (4,)
|
||||
|
||||
|
||||
def test_merge_with_different_lengths():
    """With RTC disabled, chunks of different lengths accumulate in the queue."""
    queue = ActionQueue(RTCConfig(enabled=False, execution_horizon=10))

    # Running total after each appended chunk: 10, then 10 + 25 = 35.
    first_original, first_processed = torch.randn(10, 6), torch.randn(10, 6)
    queue.merge(first_original, first_processed, real_delay=0)
    assert queue.qsize() == 10

    second_original, second_processed = torch.randn(25, 6), torch.randn(25, 6)
    queue.merge(second_original, second_processed, real_delay=0)
    assert queue.qsize() == 35
|
||||
|
||||
|
||||
# ====================== merge() Delay Validation Tests ======================
|
||||
|
||||
|
||||
def test_merge_validates_delay_consistency(action_queue_rtc_enabled, sample_actions, caplog):
    """merge() logs a warning when ``real_delay`` disagrees with the consumed count."""
    import logging

    caplog.set_level(logging.WARNING)
    queue = action_queue_rtc_enabled
    short = sample_actions["short"]

    # Fill the queue, then advance the action index to 5.
    queue.merge(short, short, real_delay=0)
    for _ in range(5):
        queue.get()

    # The index moved from 0 to 5 during "inference", yet a delay of 3 is
    # reported, so the two values do not agree.
    mismatched = {"real_delay": 3, "action_index_before_inference": 0}
    queue.merge(sample_actions["original"], sample_actions["processed"], **mismatched)

    assert "Indexes diff is not equal to real delay" in caplog.text
|
||||
|
||||
|
||||
def test_merge_no_warning_when_delays_match(action_queue_rtc_enabled, sample_actions, caplog):
    """merge() stays silent when ``real_delay`` equals the consumed count."""
    import logging

    caplog.set_level(logging.WARNING)
    queue = action_queue_rtc_enabled
    short = sample_actions["short"]

    # Fill the queue, then advance the action index to 5.
    queue.merge(short, short, real_delay=0)
    for _ in range(5):
        queue.get()

    # Index moved from 0 to 5 and a delay of 5 is reported: consistent.
    matching = {"real_delay": 5, "action_index_before_inference": 0}
    queue.merge(sample_actions["original"], sample_actions["processed"], **matching)

    assert "Indexes diff is not equal to real delay" not in caplog.text
|
||||
|
||||
|
||||
def test_merge_skips_validation_when_action_index_none(action_queue_rtc_enabled, sample_actions, caplog):
    """No delay warning is logged when ``action_index_before_inference`` is None."""
    import logging

    caplog.set_level(logging.WARNING)
    queue = action_queue_rtc_enabled
    short = sample_actions["short"]

    queue.merge(short, short, real_delay=0)
    for _ in range(5):
        queue.get()

    # The delay value is arbitrary here; without a reference index there is
    # nothing to compare it against.
    unchecked = {"real_delay": 999, "action_index_before_inference": None}
    queue.merge(sample_actions["original"], sample_actions["processed"], **unchecked)

    assert "Indexes diff is not equal to real delay" not in caplog.text
|
||||
|
||||
|
||||
# ====================== Thread Safety Tests ======================
|
||||
|
||||
|
||||
def test_get_is_thread_safe(action_queue_rtc_enabled, sample_actions):
    """Test get() is thread-safe with multiple consumers.

    Four threads each request 25 actions from a queue holding 100. The test
    checks that no thread raised, that exactly 100 actions came out, and that
    every queued action was delivered exactly once.
    """
    source = sample_actions["longer"]
    action_queue_rtc_enabled.merge(source, source, real_delay=0)

    results = []
    errors = []

    def consumer():
        # Exceptions raised inside a thread would otherwise be lost, so they
        # are collected and checked by the main thread.
        try:
            for _ in range(25):
                action = action_queue_rtc_enabled.get()
                if action is not None:
                    results.append(action)
                time.sleep(0.001)
        except Exception as e:
            errors.append(e)

    threads = [threading.Thread(target=consumer) for _ in range(4)]

    for t in threads:
        t.start()

    for t in threads:
        t.join()

    # Should not have errors
    assert len(errors) == 0

    # Should have consumed all actions (100 total, 4 threads * 25 each)
    assert len(results) == 100
    assert action_queue_rtc_enabled.qsize() == 0

    # No action may be handed out twice or skipped: the multiset of consumed
    # rows must equal the multiset of rows that were queued. Sorting makes the
    # comparison independent of the order in which the threads were scheduled.
    consumed_rows = sorted(tuple(action.tolist()) for action in results)
    queued_rows = sorted(tuple(row.tolist()) for row in source)
    assert consumed_rows == queued_rows
|
||||
|
||||
|
||||
def test_merge_is_thread_safe(action_queue_rtc_disabled, sample_actions):
    """Concurrent merge() calls from several producers all end up in the queue."""
    failures = []
    short = sample_actions["short"]

    def producer():
        # Collect exceptions so the main thread can assert on them.
        try:
            for _ in range(5):
                action_queue_rtc_disabled.merge(short, short, real_delay=0)
                time.sleep(0.001)
        except Exception as exc:
            failures.append(exc)

    workers = [threading.Thread(target=producer) for _ in range(3)]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    assert len(failures) == 0

    # 3 threads * 5 merges * 10 actions = 150
    assert action_queue_rtc_disabled.qsize() == 150
|
||||
|
||||
|
||||
def test_concurrent_get_and_merge(action_queue_rtc_disabled, sample_actions):
    """Test concurrent get() and merge() operations.

    Two consumers (50 get() attempts each) run alongside two producers
    (10 merges of 10 actions each) on an RTC-disabled queue. Besides checking
    that no thread raised, the test verifies that no action was lost or
    duplicated: everything produced is either consumed or still queued.
    """
    errors = []
    consumed_count = 0
    # ``+=`` on a shared counter is a read-modify-write and is not atomic
    # across threads, so the counter is guarded by its own lock.
    count_lock = threading.Lock()

    def consumer():
        nonlocal consumed_count
        try:
            for _ in range(50):
                action = action_queue_rtc_disabled.get()
                if action is not None:
                    with count_lock:
                        consumed_count += 1
                time.sleep(0.001)
        except Exception as e:
            errors.append(e)

    def producer():
        try:
            for _ in range(10):
                action_queue_rtc_disabled.merge(
                    sample_actions["short"], sample_actions["short"], real_delay=0
                )
                time.sleep(0.005)
        except Exception as e:
            errors.append(e)

    consumer_threads = [threading.Thread(target=consumer) for _ in range(2)]
    producer_threads = [threading.Thread(target=producer) for _ in range(2)]

    for t in consumer_threads + producer_threads:
        t.start()

    for t in consumer_threads + producer_threads:
        t.join()

    # Should not have errors
    assert len(errors) == 0

    # The number consumed depends on timing, but it can never exceed the
    # number of get() attempts: 2 consumers * 50 = 100.
    assert consumed_count <= 100

    # Total produced: 2 producers * 10 merges * 10 actions = 200. With RTC
    # disabled merge() keeps pending actions and appends the new ones, so every
    # produced action must be accounted for as consumed or still pending.
    total_produced = 2 * 10 * len(sample_actions["short"])
    assert consumed_count + action_queue_rtc_disabled.qsize() == total_produced
|
||||
|
||||
|
||||
# ====================== get_left_over() Thread Safety Tests ======================
|
||||
|
||||
|
||||
def test_get_left_over_is_thread_safe(action_queue_rtc_enabled, sample_actions):
    """Test get_left_over() is thread-safe with concurrent access.

    Three readers poll get_left_over() while one consumer takes 10 of the 100
    queued actions. Each reader must see a leftover size that stays within the
    possible range and never grows, because nothing is merged during the test.
    """
    action_queue_rtc_enabled.merge(sample_actions["longer"], sample_actions["longer"], real_delay=0)

    errors = []
    # One list per reader: ordering is only meaningful within a single thread,
    # so sizes observed by different readers must not be interleaved.
    per_reader_sizes = [[] for _ in range(3)]

    def reader(sizes):
        try:
            for _ in range(20):
                leftover = action_queue_rtc_enabled.get_left_over()
                if leftover is not None:
                    sizes.append(leftover.shape[0])
                time.sleep(0.001)
        except Exception as e:
            errors.append(e)

    threads = [threading.Thread(target=reader, args=(sizes,)) for sizes in per_reader_sizes]

    # Also consume some actions concurrently
    def consumer():
        try:
            for _ in range(10):
                action_queue_rtc_enabled.get()
                time.sleep(0.002)
        except Exception as e:
            errors.append(e)

    consumer_thread = threading.Thread(target=consumer)

    all_threads = threads + [consumer_thread]

    for t in all_threads:
        t.start()

    for t in all_threads:
        t.join()

    # Should not have errors
    assert len(errors) == 0

    total = len(sample_actions["longer"])
    for sizes in per_reader_sizes:
        # The queue was filled before the readers started, so every poll
        # returns a tensor rather than None.
        assert len(sizes) == 20
        # At most 10 actions are consumed, so the size stays in [90, 100].
        assert all(total - 10 <= size <= total for size in sizes)
        # Leftovers are monotonically decreasing or stable: as actions are
        # consumed the leftover can only shrink.
        assert all(earlier >= later for earlier, later in zip(sizes, sizes[1:]))

    # After all threads finished exactly 10 actions are gone.
    assert action_queue_rtc_enabled.get_left_over().shape[0] == total - 10
|
||||
|
||||
|
||||
# ====================== Edge Cases Tests ======================
|
||||
|
||||
|
||||
def test_queue_with_single_action(action_queue_rtc_enabled):
    """A sequence of length one can be merged and consumed."""
    queue = action_queue_rtc_enabled
    lone_original = torch.randn(1, 6)
    lone_processed = torch.randn(1, 6)

    queue.merge(lone_original, lone_processed, real_delay=0)
    assert queue.qsize() == 1

    # Taking the only action empties the queue.
    only = queue.get()
    assert only is not None
    assert only.shape == (6,)
    assert queue.qsize() == 0
|
||||
|
||||
|
||||
def test_queue_behavior_after_multiple_merge_cycles(action_queue_rtc_enabled, sample_actions):
    """Test queue maintains correct state through multiple merge cycles.

    Each cycle merges a short chunk, consumes half of it and then merges a
    long chunk with a delay. With RTC enabled every merge replaces the queue,
    so the state after each step is fully determined and is checked exactly.
    """
    delay = 3
    # The replacing merge keeps the long chunk minus the skipped prefix.
    expected_size = len(sample_actions["original"]) - delay

    for _ in range(5):
        action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
        # Leftovers from the previous cycle must be gone, not accumulated.
        assert action_queue_rtc_enabled.qsize() == len(sample_actions["short"])

        # Consume half
        for _ in range(5):
            action_queue_rtc_enabled.get()

        # Merge again
        action_queue_rtc_enabled.merge(
            sample_actions["original"], sample_actions["processed"], real_delay=delay
        )

        assert action_queue_rtc_enabled.qsize() == expected_size
        assert action_queue_rtc_enabled.get_action_index() == 0
|
||||
|
||||
|
||||
def test_queue_with_all_zeros_actions(action_queue_rtc_enabled):
    """An all-zero action sequence is stored and returned unchanged."""
    blank = torch.zeros(20, 6)
    action_queue_rtc_enabled.merge(blank, blank, real_delay=0)

    head = action_queue_rtc_enabled.get()
    assert torch.all(head == 0)
|
||||
|
||||
|
||||
def test_queue_clones_input_tensors(action_queue_rtc_enabled, sample_actions):
    """merge() stores copies, so later edits to the inputs do not reach the queue."""
    queue = action_queue_rtc_enabled
    original = sample_actions["original"]
    processed = sample_actions["processed"]

    # Remember the values as they were at merge time.
    original_snapshot = original.clone()
    processed_snapshot = processed.clone()

    queue.merge(original, processed, real_delay=0)

    # Overwrite the caller-side tensors in place.
    original.fill_(999.0)
    processed.fill_(-999.0)

    # The queued action still carries the pre-overwrite values.
    head = queue.get()
    assert not torch.equal(head, processed[0])
    assert torch.equal(head, processed_snapshot[0])

    # The same holds for the stored original actions.
    tail = queue.get_left_over()
    assert not torch.equal(tail, original[1:])
    assert torch.equal(tail, original_snapshot[1:])
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_queue_handles_gpu_tensors():
    """Tensors merged on a CUDA device come back on a CUDA device."""
    queue = ActionQueue(RTCConfig(enabled=True, execution_horizon=10))

    cuda_actions = torch.randn(20, 6, device="cuda")
    queue.merge(cuda_actions, cuda_actions, real_delay=0)

    head = queue.get()
    assert head.device.type == "cuda"

    tail = queue.get_left_over()
    assert tail.device.type == "cuda"
|
||||
|
||||
|
||||
def test_queue_handles_different_dtypes():
    """The dtype of merged actions is preserved by get()."""
    queue = ActionQueue(RTCConfig(enabled=True, execution_horizon=10))

    # float64 rather than torch's default float32.
    double_actions = torch.randn(20, 6, dtype=torch.float64)
    queue.merge(double_actions, double_actions, real_delay=0)

    head = queue.get()
    assert head.dtype == torch.float64
|
||||
|
||||
|
||||
def test_empty_with_none_queue(action_queue_rtc_enabled):
    """empty() copes with the internal ``queue`` attribute being None."""
    queue = action_queue_rtc_enabled

    # Precondition: nothing was merged, so the buffer is unset.
    assert queue.queue is None
    assert queue.empty() is True
|
||||
|
||||
|
||||
def test_qsize_with_none_queue(action_queue_rtc_enabled):
    """qsize() copes with the internal ``queue`` attribute being None."""
    queue = action_queue_rtc_enabled

    # Precondition: nothing was merged, so the buffer is unset.
    assert queue.queue is None
    assert queue.qsize() == 0
|
||||
|
||||
|
||||
# ====================== Integration Tests ======================
|
||||
|
||||
|
||||
def test_typical_rtc_workflow(action_queue_rtc_enabled, sample_actions):
    """End-to-end RTC flow: merge, consume a horizon, merge again with a delay."""
    queue = action_queue_rtc_enabled
    original = sample_actions["original"]
    processed = sample_actions["processed"]

    # First inference result fills the queue.
    queue.merge(original, processed, real_delay=0)
    size_after_first = queue.qsize()
    assert size_after_first == 50

    # Execute one horizon worth of actions (10).
    for _ in range(10):
        assert queue.get() is not None
    assert queue.qsize() == 40

    # Second inference: record the index, then merge with a delay of 5.
    index_before = queue.get_action_index()
    queue.merge(
        original,
        processed,
        real_delay=5,
        action_index_before_inference=index_before,
    )

    # The queue was replaced by the new chunk minus the delayed prefix.
    assert queue.qsize() == 45
    assert queue.get_action_index() == 0
|
||||
|
||||
|
||||
def test_typical_non_rtc_workflow(action_queue_rtc_disabled, sample_actions):
    """End-to-end non-RTC flow: merge, consume most actions, merge again."""
    queue = action_queue_rtc_disabled
    original = sample_actions["original"]
    processed = sample_actions["processed"]

    # First inference result fills the queue.
    queue.merge(original, processed, real_delay=0)
    assert queue.qsize() == 50

    # Execute 40 of the 50 actions.
    for _ in range(40):
        assert queue.get() is not None
    assert queue.qsize() == 10

    # Second inference result is appended: 10 pending + 50 new = 60.
    queue.merge(original, processed, real_delay=0)
    assert queue.qsize() == 60
|
||||
@@ -0,0 +1,323 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for RTC configuration module."""
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
|
||||
# ====================== Initialization Tests ======================
|
||||
|
||||
|
||||
def test_rtc_config_default_initialization():
    """Check each field of an argument-less RTCConfig against the defaults pinned here."""
    cfg = RTCConfig()

    assert cfg.enabled is False
    assert cfg.prefix_attention_schedule == RTCAttentionSchedule.LINEAR
    assert cfg.max_guidance_weight == 10.0
    assert cfg.execution_horizon == 10
    assert cfg.debug is False
    assert cfg.debug_maxlen == 100
|
||||
|
||||
|
||||
def test_rtc_config_custom_initialization():
    """Override every RTCConfig field and confirm each override is stored as given."""
    overrides = {
        "enabled": True,
        "prefix_attention_schedule": RTCAttentionSchedule.EXP,
        "max_guidance_weight": 5.0,
        "execution_horizon": 20,
        "debug": True,
        "debug_maxlen": 200,
    }
    cfg = RTCConfig(**overrides)

    assert cfg.enabled is True
    assert cfg.prefix_attention_schedule == RTCAttentionSchedule.EXP
    assert cfg.max_guidance_weight == 5.0
    assert cfg.execution_horizon == 20
    assert cfg.debug is True
    assert cfg.debug_maxlen == 200
|
||||
|
||||
|
||||
def test_rtc_config_partial_initialization():
    """Override only two fields and confirm the remaining ones keep their defaults."""
    cfg = RTCConfig(enabled=True, max_guidance_weight=15.0)

    # Explicitly supplied values.
    assert cfg.enabled is True
    assert cfg.max_guidance_weight == 15.0

    # Untouched fields fall back to the defaults.
    assert cfg.prefix_attention_schedule == RTCAttentionSchedule.LINEAR
    assert cfg.execution_horizon == 10
    assert cfg.debug is False
|
||||
|
||||
|
||||
# ====================== Validation Tests ======================
|
||||
|
||||
|
||||
def test_rtc_config_validates_positive_max_guidance_weight():
    """Zero and negative max_guidance_weight values must both raise ValueError."""
    for invalid_weight in (0.0, -1.0):
        with pytest.raises(ValueError, match="max_guidance_weight must be positive"):
            RTCConfig(max_guidance_weight=invalid_weight)
|
||||
|
||||
|
||||
def test_rtc_config_validates_positive_debug_maxlen():
    """Zero and negative debug_maxlen values must both raise ValueError."""
    for invalid_maxlen in (0, -10):
        with pytest.raises(ValueError, match="debug_maxlen must be positive"):
            RTCConfig(debug_maxlen=invalid_maxlen)
|
||||
|
||||
|
||||
def test_rtc_config_accepts_valid_max_guidance_weight():
    """Small and large positive max_guidance_weight values are accepted unchanged."""
    for weight in (0.1, 100.0):
        cfg = RTCConfig(max_guidance_weight=weight)
        assert cfg.max_guidance_weight == weight
|
||||
|
||||
|
||||
def test_rtc_config_accepts_valid_debug_maxlen():
    """Small and large positive debug_maxlen values are accepted unchanged."""
    for maxlen in (1, 10000):
        cfg = RTCConfig(debug_maxlen=maxlen)
        assert cfg.debug_maxlen == maxlen
|
||||
|
||||
|
||||
# ====================== Attention Schedule Tests ======================
|
||||
|
||||
|
||||
def test_rtc_config_with_linear_schedule():
    """The LINEAR attention schedule is stored as passed."""
    schedule = RTCAttentionSchedule.LINEAR
    cfg = RTCConfig(prefix_attention_schedule=schedule)
    assert cfg.prefix_attention_schedule == RTCAttentionSchedule.LINEAR
|
||||
|
||||
|
||||
def test_rtc_config_with_exp_schedule():
    """The EXP attention schedule is stored as passed."""
    schedule = RTCAttentionSchedule.EXP
    cfg = RTCConfig(prefix_attention_schedule=schedule)
    assert cfg.prefix_attention_schedule == RTCAttentionSchedule.EXP
|
||||
|
||||
|
||||
def test_rtc_config_with_zeros_schedule():
    """The ZEROS attention schedule is stored as passed."""
    schedule = RTCAttentionSchedule.ZEROS
    cfg = RTCConfig(prefix_attention_schedule=schedule)
    assert cfg.prefix_attention_schedule == RTCAttentionSchedule.ZEROS
|
||||
|
||||
|
||||
def test_rtc_config_with_ones_schedule():
    """The ONES attention schedule is stored as passed."""
    schedule = RTCAttentionSchedule.ONES
    cfg = RTCConfig(prefix_attention_schedule=schedule)
    assert cfg.prefix_attention_schedule == RTCAttentionSchedule.ONES
|
||||
|
||||
|
||||
# ====================== Enabled/Disabled Tests ======================
|
||||
|
||||
|
||||
def test_rtc_config_enabled_true():
    """Passing enabled=True yields a config whose enabled flag is True."""
    cfg = RTCConfig(enabled=True)
    assert cfg.enabled is True
|
||||
|
||||
|
||||
def test_rtc_config_enabled_false():
    """Passing enabled=False yields a config whose enabled flag is False."""
    cfg = RTCConfig(enabled=False)
    assert cfg.enabled is False
|
||||
|
||||
|
||||
# ====================== Debug Tests ======================
|
||||
|
||||
|
||||
def test_rtc_config_debug_enabled():
    """Debug mode and a custom debug_maxlen are both stored as passed."""
    cfg = RTCConfig(debug=True, debug_maxlen=500)

    assert cfg.debug is True
    assert cfg.debug_maxlen == 500
|
||||
|
||||
|
||||
def test_rtc_config_debug_disabled():
    """Passing debug=False yields a config whose debug flag is False."""
    cfg = RTCConfig(debug=False)
    assert cfg.debug is False
|
||||
|
||||
|
||||
# ====================== Execution Horizon Tests ======================
|
||||
|
||||
|
||||
def test_rtc_config_with_small_execution_horizon():
    """An execution horizon of 1 is stored as passed."""
    horizon = 1
    cfg = RTCConfig(execution_horizon=horizon)
    assert cfg.execution_horizon == 1
|
||||
|
||||
|
||||
def test_rtc_config_with_large_execution_horizon():
    """An execution horizon of 100 is stored as passed."""
    horizon = 100
    cfg = RTCConfig(execution_horizon=horizon)
    assert cfg.execution_horizon == 100
|
||||
|
||||
|
||||
def test_rtc_config_with_zero_execution_horizon():
    """An execution horizon of 0 is accepted and stored.

    This test relies on RTCConfig performing no validation of execution_horizon.
    """
    horizon = 0
    cfg = RTCConfig(execution_horizon=horizon)
    assert cfg.execution_horizon == 0
|
||||
|
||||
|
||||
def test_rtc_config_with_negative_execution_horizon():
    """A negative execution horizon is accepted and stored.

    This test relies on RTCConfig performing no validation of execution_horizon.
    """
    horizon = -1
    cfg = RTCConfig(execution_horizon=horizon)
    assert cfg.execution_horizon == -1
|
||||
|
||||
|
||||
# ====================== Integration Tests ======================
|
||||
|
||||
|
||||
def test_rtc_config_typical_production_settings():
    """Build a production-style config (RTC on, debug off) and verify every supplied field."""
    production_settings = {
        "enabled": True,
        "prefix_attention_schedule": RTCAttentionSchedule.EXP,
        "max_guidance_weight": 10.0,
        "execution_horizon": 8,
        "debug": False,
    }
    cfg = RTCConfig(**production_settings)

    assert cfg.enabled is True
    assert cfg.prefix_attention_schedule == RTCAttentionSchedule.EXP
    assert cfg.max_guidance_weight == 10.0
    assert cfg.execution_horizon == 8
    assert cfg.debug is False
|
||||
|
||||
|
||||
def test_rtc_config_typical_debug_settings():
    """Build a debug-style config (RTC on, debug on) and verify the debug-related fields."""
    debug_settings = {
        "enabled": True,
        "prefix_attention_schedule": RTCAttentionSchedule.LINEAR,
        "max_guidance_weight": 5.0,
        "execution_horizon": 10,
        "debug": True,
        "debug_maxlen": 1000,
    }
    cfg = RTCConfig(**debug_settings)

    assert cfg.enabled is True
    assert cfg.debug is True
    assert cfg.debug_maxlen == 1000
|
||||
|
||||
|
||||
def test_rtc_config_disabled_mode():
    """A disabled config still exposes its other settings at their default values."""
    cfg = RTCConfig(enabled=False)

    assert cfg.enabled is False

    # The remaining fields stay readable while RTC is switched off.
    assert cfg.max_guidance_weight == 10.0
    assert cfg.execution_horizon == 10
|
||||
|
||||
|
||||
# ====================== Dataclass Tests ======================
|
||||
|
||||
|
||||
def test_rtc_config_is_dataclass():
    """RTCConfig must be recognised by ``dataclasses.is_dataclass``."""
    from dataclasses import is_dataclass

    recognised = is_dataclass(RTCConfig)
    assert recognised
|
||||
|
||||
|
||||
def test_rtc_config_equality():
    """Configs with identical fields compare equal; a differing field makes them unequal."""
    enabled_a = RTCConfig(enabled=True, max_guidance_weight=5.0)
    enabled_b = RTCConfig(enabled=True, max_guidance_weight=5.0)
    disabled = RTCConfig(enabled=False, max_guidance_weight=5.0)

    assert enabled_a == enabled_b
    assert enabled_a != disabled
|
||||
|
||||
|
||||
def test_rtc_config_repr():
    """The repr names the class and shows the field values that were passed in."""
    text = repr(RTCConfig(enabled=True, execution_horizon=20))

    for fragment in ("RTCConfig", "enabled=True", "execution_horizon=20"):
        assert fragment in text
|
||||
|
||||
|
||||
# ====================== Edge Cases Tests ======================
|
||||
|
||||
|
||||
def test_rtc_config_very_small_max_guidance_weight():
    """A tiny but positive max_guidance_weight (1e-10) is accepted."""
    tiny_weight = 1e-10
    cfg = RTCConfig(max_guidance_weight=tiny_weight)
    assert cfg.max_guidance_weight == pytest.approx(1e-10)
|
||||
|
||||
|
||||
def test_rtc_config_very_large_max_guidance_weight():
    """A very large max_guidance_weight (1e10) is accepted."""
    huge_weight = 1e10
    cfg = RTCConfig(max_guidance_weight=huge_weight)
    assert cfg.max_guidance_weight == pytest.approx(1e10)
|
||||
|
||||
|
||||
def test_rtc_config_minimum_debug_maxlen():
    """The smallest valid debug_maxlen, 1, is accepted."""
    smallest_maxlen = 1
    cfg = RTCConfig(debug_maxlen=smallest_maxlen)
    assert cfg.debug_maxlen == 1
|
||||
|
||||
|
||||
def test_rtc_config_float_max_guidance_weight():
    """A non-round float max_guidance_weight is stored (compared approximately)."""
    weight = 3.14159
    cfg = RTCConfig(max_guidance_weight=weight)
    assert cfg.max_guidance_weight == pytest.approx(3.14159)
|
||||
|
||||
|
||||
# ====================== Type Tests ======================
|
||||
|
||||
|
||||
def test_rtc_config_enabled_type():
    """The enabled field holds a bool when a bool is passed."""
    stored = RTCConfig(enabled=True).enabled
    assert isinstance(stored, bool)
|
||||
|
||||
|
||||
def test_rtc_config_execution_horizon_type():
    """The execution_horizon field holds an int when an int is passed."""
    stored = RTCConfig(execution_horizon=15).execution_horizon
    assert isinstance(stored, int)
|
||||
|
||||
|
||||
def test_rtc_config_max_guidance_weight_type():
    """The max_guidance_weight field holds a float when a float is passed."""
    stored = RTCConfig(max_guidance_weight=7.5).max_guidance_weight
    assert isinstance(stored, float)
|
||||
|
||||
|
||||
def test_rtc_config_debug_maxlen_type():
    """The debug_maxlen field holds an int when an int is passed."""
    stored = RTCConfig(debug_maxlen=200).debug_maxlen
    assert isinstance(stored, int)
|
||||
@@ -0,0 +1,427 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for RTC debug visualizer module."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.policies.rtc.debug_visualizer import RTCDebugVisualizer
|
||||
|
||||
# ====================== Fixtures ======================
|
||||
|
||||
|
||||
@pytest.fixture
def mock_axes():
    """Provide six MagicMock axes whose x/y axis labels both read as empty strings."""
    axes = [MagicMock() for _ in range(6)]
    for axis in axes:
        # Empty label text stands for "no label set yet" on each mock axis.
        axis.xaxis.get_label.return_value.get_text.return_value = ""
        axis.yaxis.get_label.return_value.get_text.return_value = ""
    return axes
|
||||
|
||||
|
||||
@pytest.fixture
def sample_tensor_2d():
    """Provide a random 2D tensor of shape (50, 6), i.e. (time_steps, num_dims)."""
    time_steps, num_dims = 50, 6
    return torch.randn(time_steps, num_dims)
|
||||
|
||||
|
||||
@pytest.fixture
def sample_tensor_3d():
    """Provide a random 3D tensor of shape (1, 50, 6), i.e. (batch, time_steps, num_dims)."""
    batch, time_steps, num_dims = 1, 50, 6
    return torch.randn(batch, time_steps, num_dims)
|
||||
|
||||
|
||||
@pytest.fixture
def sample_numpy_2d():
    """Provide a random 2D numpy array of shape (50, 6)."""
    time_steps, num_dims = 50, 6
    return np.random.randn(time_steps, num_dims)
|
||||
|
||||
|
||||
# ====================== Basic Plotting Tests ======================
|
||||
|
||||
|
||||
def test_plot_waypoints_with_2d_tensor(mock_axes, sample_tensor_2d):
    """A (50, 6) tensor results in exactly one plot call per axis."""
    RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d)

    # One axis per tensor dimension, six in total.
    for axis in mock_axes:
        axis.plot.assert_called_once()
|
||||
|
||||
|
||||
def test_plot_waypoints_with_3d_tensor(mock_axes, sample_tensor_3d):
    """A batched (1, 50, 6) tensor still results in one plot call per axis."""
    RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_3d)

    # The leading batch dimension is expected not to add extra plots.
    for axis in mock_axes:
        axis.plot.assert_called_once()
|
||||
|
||||
|
||||
def test_plot_waypoints_with_numpy_array(mock_axes, sample_numpy_2d):
    """A numpy array input is plotted just like a tensor: one plot call per axis."""
    RTCDebugVisualizer.plot_waypoints(mock_axes, sample_numpy_2d)

    for axis in mock_axes:
        axis.plot.assert_called_once()
|
||||
|
||||
|
||||
def test_plot_waypoints_with_none_tensor(mock_axes):
    """Passing None as the tensor must leave every axis without a plot call."""
    RTCDebugVisualizer.plot_waypoints(mock_axes, None)

    for axis in mock_axes:
        axis.plot.assert_not_called()
|
||||
|
||||
|
||||
# ====================== Parameter Tests ======================
|
||||
|
||||
|
||||
def test_plot_waypoints_with_custom_color(mock_axes, sample_tensor_2d):
    """The color argument is forwarded to every axis' plot call."""
    RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, color="red")

    for axis in mock_axes:
        forwarded = axis.plot.call_args.kwargs
        assert forwarded["color"] == "red"
|
||||
|
||||
|
||||
def test_plot_waypoints_with_custom_label(mock_axes, sample_tensor_2d):
    """The label is attached to the first axis only; the others receive an empty label."""
    RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, label="test_label")

    labelled_axis, *unlabelled_axes = mock_axes

    # Only the first axis carries the label text.
    assert labelled_axis.plot.call_args.kwargs["label"] == "test_label"

    # Every remaining axis is plotted with an empty label.
    for axis in unlabelled_axes:
        assert axis.plot.call_args.kwargs["label"] == ""
|
||||
|
||||
|
||||
def test_plot_waypoints_with_custom_alpha(mock_axes, sample_tensor_2d):
    """The alpha argument is forwarded to every axis' plot call."""
    RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, alpha=0.5)

    for axis in mock_axes:
        forwarded = axis.plot.call_args.kwargs
        assert forwarded["alpha"] == 0.5
|
||||
|
||||
|
||||
def test_plot_waypoints_with_custom_linewidth(mock_axes, sample_tensor_2d):
    """The linewidth argument is forwarded to every axis' plot call."""
    RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, linewidth=3)

    for axis in mock_axes:
        forwarded = axis.plot.call_args.kwargs
        assert forwarded["linewidth"] == 3
|
||||
|
||||
|
||||
def test_plot_waypoints_with_marker(mock_axes, sample_tensor_2d):
    """marker and markersize are both forwarded to every axis' plot call."""
    RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, marker="o", markersize=5)

    for axis in mock_axes:
        forwarded = axis.plot.call_args.kwargs
        assert forwarded["marker"] == "o"
        assert forwarded["markersize"] == 5
|
||||
|
||||
|
||||
def test_plot_waypoints_without_marker(mock_axes, sample_tensor_2d):
    """With marker=None, neither marker nor markersize reaches the plot call."""
    RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, marker=None)

    for axis in mock_axes:
        forwarded = axis.plot.call_args.kwargs
        # Both marker-related keywords must be absent, not merely None.
        assert "marker" not in forwarded
        assert "markersize" not in forwarded
|
||||
|
||||
|
||||
# ====================== start_from Parameter Tests ======================
|
||||
|
||||
|
||||
def test_plot_waypoints_with_start_from_zero(mock_axes, sample_tensor_2d):
    """With start_from=0 the x indices begin at 0 and cover all 50 steps."""
    RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, start_from=0)

    for axis in mock_axes:
        # The first positional argument of plot() holds the x indices.
        steps = axis.plot.call_args.args[0]
        assert steps[0] == 0
        assert len(steps) == 50
|
||||
|
||||
|
||||
def test_plot_waypoints_with_start_from_nonzero(mock_axes, sample_tensor_2d):
    """With start_from=10 the x indices are shifted to run from 10 through 59."""
    RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, start_from=10)

    for axis in mock_axes:
        # The first positional argument of plot() holds the x indices.
        steps = axis.plot.call_args.args[0]
        assert steps[0] == 10
        assert steps[-1] == 59  # 10 + 50 - 1
|
||||
|
||||
|
||||
# ====================== Tensor Shape Tests ======================
|
||||
|
||||
|
||||
def test_plot_waypoints_with_1d_tensor(mock_axes):
    """A 1D tensor of 50 values is plotted on the first axis only."""
    flat_tensor = torch.randn(50)
    RTCDebugVisualizer.plot_waypoints(mock_axes, flat_tensor)

    # Expected to be treated as a single dimension: only the first axis gets a plot.
    first_axis, *other_axes = mock_axes
    first_axis.plot.assert_called_once()
    for axis in other_axes:
        axis.plot.assert_not_called()
|
||||
|
||||
|
||||
def test_plot_waypoints_with_fewer_dims_than_axes(sample_tensor_2d):
    """When the tensor has 3 columns but 6 axes are given, only the first 3 axes are plotted."""
    # Keep just the first three columns of the (50, 6) sample.
    narrow_tensor = sample_tensor_2d[:, :3]

    # Six unlabelled mock axes, twice as many as the tensor has columns.
    axes = []
    for _ in range(6):
        axis = MagicMock()
        axis.xaxis.get_label.return_value.get_text.return_value = ""
        axis.yaxis.get_label.return_value.get_text.return_value = ""
        axes.append(axis)

    RTCDebugVisualizer.plot_waypoints(axes, narrow_tensor)

    for used_axis in axes[:3]:
        used_axis.plot.assert_called_once()
    for unused_axis in axes[3:6]:
        unused_axis.plot.assert_not_called()
|
||||
|
||||
|
||||
# ====================== Axis Labeling Tests ======================
|
||||
|
||||
|
||||
def test_plot_waypoints_sets_xlabel(mock_axes, sample_tensor_2d):
    """Every unlabelled axis receives the x label "Step" at font size 10."""
    RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d)

    for axis in mock_axes:
        axis.set_xlabel.assert_called_once_with("Step", fontsize=10)
|
||||
|
||||
|
||||
def test_plot_waypoints_sets_ylabel(mock_axes, sample_tensor_2d):
    """Each unlabelled axis receives a y label naming its dimension index."""
    RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d)

    for dim, axis in enumerate(mock_axes):
        expected_label = f"Dim {dim}"
        axis.set_ylabel.assert_called_once_with(expected_label, fontsize=10)
|
||||
|
||||
|
||||
def test_plot_waypoints_skips_label_if_exists(sample_tensor_2d):
    """Axes that already report label text must not have their labels overwritten."""
    labelled_axes = [MagicMock() for _ in range(6)]
    for axis in labelled_axes:
        # Make each mock axis report pre-existing label text.
        axis.xaxis.get_label.return_value.get_text.return_value = "Existing X Label"
        axis.yaxis.get_label.return_value.get_text.return_value = "Existing Y Label"

    RTCDebugVisualizer.plot_waypoints(labelled_axes, sample_tensor_2d)

    for axis in labelled_axes:
        axis.set_xlabel.assert_not_called()
        axis.set_ylabel.assert_not_called()
|
||||
|
||||
|
||||
# ====================== Grid Tests ======================
|
||||
|
||||
|
||||
def test_plot_waypoints_enables_grid(mock_axes, sample_tensor_2d):
    """Every axis gets its grid switched on with alpha 0.3."""
    RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d)

    for axis in mock_axes:
        axis.grid.assert_called_once_with(True, alpha=0.3)
|
||||
|
||||
|
||||
# ====================== Legend Tests ======================
|
||||
|
||||
|
||||
def test_plot_waypoints_adds_legend_with_label(mock_axes, sample_tensor_2d):
    """With a label, a legend is added to the first axis and to no other."""
    RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, label="test_label")

    first_axis, *other_axes = mock_axes

    # The legend goes on the first axis only.
    first_axis.legend.assert_called_once_with(loc="best", fontsize=8)

    for axis in other_axes:
        axis.legend.assert_not_called()
|
||||
|
||||
|
||||
def test_plot_waypoints_no_legend_without_label(mock_axes, sample_tensor_2d):
    """With an empty label, no axis gets a legend."""
    RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, label="")

    for axis in mock_axes:
        axis.legend.assert_not_called()
|
||||
|
||||
|
||||
# ====================== Data Correctness Tests ======================
|
||||
|
||||
|
||||
def test_plot_waypoints_plots_correct_data(mock_axes, sample_tensor_2d):
    """The first axis is plotted with x = 0..49 and y = column 0 of the tensor."""
    RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, start_from=0)

    # Inspect the positional arguments handed to plot() on the first axis.
    plotted = mock_axes[0].plot.call_args.args
    steps = plotted[0]
    values = plotted[1]

    # x: one index per time step, starting at 0.
    np.testing.assert_array_equal(steps, np.arange(50))

    # y: the first column of the input tensor.
    first_column = sample_tensor_2d[:, 0].cpu().numpy()
    np.testing.assert_array_almost_equal(values, first_column)
|
||||
|
||||
|
||||
def test_plot_waypoints_handles_gpu_tensor(mock_axes):
    """A CUDA tensor is plotted like a CPU one; the test is skipped when CUDA is unavailable."""
    cuda_ready = torch.cuda.is_available()
    if not cuda_ready:
        pytest.skip("CUDA not available")

    cuda_tensor = torch.randn(50, 6, device="cuda")
    RTCDebugVisualizer.plot_waypoints(mock_axes, cuda_tensor)

    # Plotting must complete and reach every axis exactly once.
    for axis in mock_axes:
        axis.plot.assert_called_once()
|
||||
|
||||
|
||||
# ====================== Edge Cases Tests ======================
|
||||
|
||||
|
||||
def test_plot_waypoints_with_empty_tensor(mock_axes):
    """A tensor with zero time steps is plotted with zero x indices on every axis."""
    no_steps = torch.empty(0, 6)
    RTCDebugVisualizer.plot_waypoints(mock_axes, no_steps)

    for axis in mock_axes:
        steps = axis.plot.call_args.args[0]
        assert len(steps) == 0
|
||||
|
||||
|
||||
def test_plot_waypoints_with_single_timestep(mock_axes):
    """A tensor with a single time step is plotted with exactly one x index per axis."""
    one_step = torch.randn(1, 6)
    RTCDebugVisualizer.plot_waypoints(mock_axes, one_step)

    for axis in mock_axes:
        steps = axis.plot.call_args.args[0]
        assert len(steps) == 1
|
||||
|
||||
|
||||
def test_plot_waypoints_with_very_large_tensor(mock_axes):
    """A 10000-step tensor is plotted in full: 10000 x indices per axis."""
    many_steps = torch.randn(10000, 6)
    RTCDebugVisualizer.plot_waypoints(mock_axes, many_steps)

    for axis in mock_axes:
        steps = axis.plot.call_args.args[0]
        assert len(steps) == 10000
|
||||
|
||||
|
||||
# ====================== Multiple Calls Tests ======================
|
||||
|
||||
|
||||
def test_plot_waypoints_multiple_calls_on_same_axes(mock_axes, sample_tensor_2d):
    """Plotting two series onto the same axes produces two plot calls per axis."""
    first_series = sample_tensor_2d
    second_series = torch.randn(50, 6)

    RTCDebugVisualizer.plot_waypoints(mock_axes, first_series, color="blue", label="Series 1")
    RTCDebugVisualizer.plot_waypoints(mock_axes, second_series, color="red", label="Series 2")

    for axis in mock_axes:
        assert axis.plot.call_count == 2
|
||||
|
||||
|
||||
# ====================== Integration Tests ======================
|
||||
|
||||
|
||||
def test_plot_waypoints_typical_usage(mock_axes, sample_tensor_2d):
    """A typical labelled call plots, labels and grids every axis, and puts a legend on the first."""
    RTCDebugVisualizer.plot_waypoints(
        mock_axes,
        sample_tensor_2d,
        start_from=0,
        color="blue",
        label="Trajectory",
        alpha=0.7,
        linewidth=2,
    )

    # Each axis sees exactly one plot, one x label, one y label and one grid call.
    for axis in mock_axes:
        axis.plot.assert_called_once()
        axis.set_xlabel.assert_called_once()
        axis.set_ylabel.assert_called_once()
        axis.grid.assert_called_once()

    # The label produces a legend on the first axis.
    mock_axes[0].legend.assert_called_once()
|
||||
|
||||
|
||||
def test_plot_waypoints_with_all_parameters(mock_axes, sample_tensor_2d):
    """With every parameter supplied by keyword, the styling options all reach the first axis."""
    styling = {
        "color": "green",
        "label": "Full Test",
        "alpha": 0.8,
        "linewidth": 3,
        "marker": "o",
        "markersize": 6,
    }
    RTCDebugVisualizer.plot_waypoints(
        axes=mock_axes,
        tensor=sample_tensor_2d,
        start_from=10,
        **styling,
    )

    # The first axis carries the label, so it is the one checked for all options.
    forwarded = mock_axes[0].plot.call_args.kwargs
    assert forwarded["color"] == "green"
    assert forwarded["label"] == "Full Test"
    assert forwarded["alpha"] == 0.8
    assert forwarded["linewidth"] == 3
    assert forwarded["marker"] == "o"
    assert forwarded["markersize"] == 6
|
||||
@@ -0,0 +1,481 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for RTC LatencyTracker module."""
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.policies.rtc.latency_tracker import LatencyTracker
|
||||
|
||||
# ====================== Fixtures ======================
|
||||
|
||||
|
||||
@pytest.fixture
def tracker():
    """Provide a fresh LatencyTracker constructed with maxlen=100."""
    window = 100
    return LatencyTracker(maxlen=window)
|
||||
|
||||
|
||||
@pytest.fixture
def small_tracker():
    """Provide a fresh LatencyTracker constructed with maxlen=5, for overflow scenarios."""
    window = 5
    return LatencyTracker(maxlen=window)
|
||||
|
||||
|
||||
# ====================== Initialization Tests ======================
|
||||
|
||||
|
||||
def test_latency_tracker_initialization():
    """A new tracker is empty and reports a maximum of 0.0 via both attribute and method."""
    fresh = LatencyTracker(maxlen=50)

    assert len(fresh) == 0
    assert fresh.max_latency == 0.0
    assert fresh.max() == 0.0
|
||||
|
||||
|
||||
def test_latency_tracker_default_maxlen():
    """LatencyTracker can be constructed without arguments and starts out empty."""
    fresh = LatencyTracker()

    # Construction with no maxlen must succeed; the tracker holds nothing yet.
    assert len(fresh) == 0
|
||||
|
||||
|
||||
# ====================== add() Tests ======================
|
||||
|
||||
|
||||
def test_add_single_latency(tracker):
    """One added value is counted and becomes the maximum."""
    tracker.add(0.5)

    assert len(tracker) == 1
    assert tracker.max() == 0.5
|
||||
|
||||
|
||||
def test_add_multiple_latencies(tracker):
    """Five added values are all counted and the largest one is the maximum."""
    for sample in (0.1, 0.5, 0.3, 0.8, 0.2):
        tracker.add(sample)

    assert len(tracker) == 5
    assert tracker.max() == 0.8
|
||||
|
||||
|
||||
def test_add_negative_latency_ignored(tracker):
    """A negative value is dropped: it neither counts nor affects the maximum."""
    for sample in (0.5, -0.1, 0.3):
        tracker.add(sample)

    # Only the two non-negative samples are kept.
    assert len(tracker) == 2
    assert tracker.max() == 0.5
|
||||
|
||||
|
||||
def test_add_zero_latency(tracker):
    """A latency of exactly zero is recorded rather than ignored."""
    tracker.add(0.0)

    assert len(tracker) == 1
    assert tracker.max() == 0.0
|
||||
|
||||
|
||||
def test_add_converts_to_float(tracker):
    """Non-float inputs (an int and a numeric string) are accepted by add()."""
    integer_sample = 5
    string_sample = "3.5"
    tracker.add(integer_sample)
    tracker.add(string_sample)

    assert len(tracker) == 2
    assert tracker.max() == 5.0
|
||||
|
||||
|
||||
def test_add_updates_max_latency(tracker):
    """max_latency rises with larger samples and never drops for smaller ones."""
    # (sample added, max_latency expected afterwards)
    steps = (
        (0.5, 0.5),
        (0.3, 0.5),  # smaller sample: maximum stays put
        (0.9, 0.9),  # larger sample: maximum moves up
    )
    for sample, expected_max in steps:
        tracker.add(sample)
        assert tracker.max_latency == expected_max
|
||||
|
||||
|
||||
# ====================== reset() Tests ======================
|
||||
|
||||
|
||||
def test_reset_clears_values(tracker):
    """reset() empties the tracker and puts max_latency back to 0.0."""
    for sample in (0.5, 0.8, 0.3):
        tracker.add(sample)
    assert len(tracker) == 3

    tracker.reset()

    assert len(tracker) == 0
    assert tracker.max_latency == 0.0
|
||||
|
||||
|
||||
def test_reset_clears_max_latency(tracker):
    """max_latency goes back to 0.0 once reset() is called."""
    peak = 1.5
    tracker.add(peak)
    assert tracker.max_latency == peak

    tracker.reset()
    assert tracker.max_latency == 0.0
|
||||
|
||||
|
||||
def test_reset_allows_new_values(tracker):
    """Samples added after reset() are tracked without any trace of earlier ones."""
    tracker.add(0.5)
    tracker.reset()

    fresh = 0.3
    tracker.add(fresh)

    assert len(tracker) == 1
    assert tracker.max() == fresh
|
||||
|
||||
|
||||
# ====================== max() Tests ======================
|
||||
|
||||
|
||||
def test_max_returns_zero_when_empty(tracker):
    """With no samples recorded, max() reports 0.0."""
    empty_max = tracker.max()
    assert empty_max == 0.0
|
||||
|
||||
|
||||
def test_max_returns_maximum_value(tracker):
    """max() picks the largest sample out of an unsorted batch."""
    samples = (0.2, 0.8, 0.3, 0.5, 0.1)
    for sample in samples:
        tracker.add(sample)

    assert tracker.max() == 0.8
|
||||
|
||||
|
||||
def test_max_persists_after_sliding_window(small_tracker):
    """Test max() persists after the maximum itself slides out of the window.

    The peak sample is added early and then followed by enough smaller samples
    to evict it from the sliding window (the ``small_tracker`` fixture is taken
    to hold 5 samples, as the rest of this module assumes). The final assertion
    can therefore only pass if max() remembers values that are no longer stored;
    previously the peak was still inside the window when max() was checked.
    """
    small_tracker.add(0.1)
    small_tracker.add(0.9)  # Peak value; evicted from the window below

    # Five further samples fill the window completely, pushing out 0.1 and 0.9.
    for later_sample in (0.2, 0.3, 0.4, 0.5, 0.6):
        small_tracker.add(later_sample)

    # Sanity check: the window is capped, so the two oldest samples are gone.
    assert len(small_tracker) == 5

    # The peak is no longer in the window, yet max() must still report it.
    assert small_tracker.max() == 0.9
|
||||
|
||||
|
||||
def test_max_after_reset(tracker):
    """reset() wipes the recorded maximum, so max() reports 0.0 again."""
    tracker.add(1.5)
    tracker.reset()

    post_reset_max = tracker.max()
    assert post_reset_max == 0.0
|
||||
|
||||
|
||||
# ====================== percentile() Tests ======================
|
||||
|
||||
|
||||
def test_percentile_returns_zero_when_empty(tracker):
    """percentile() reports 0.0 on an empty tracker for q=0.5 and q=0.95."""
    for q in (0.5, 0.95):
        assert tracker.percentile(q) == 0.0
|
||||
|
||||
|
||||
def test_percentile_median(tracker):
    """percentile(0.5) lands near the middle of nine evenly spaced samples."""
    # 0.1 .. 0.9 in ascending order, so the exact median is 0.5.
    for sample in (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9):
        tracker.add(sample)

    # A tolerance band is used instead of an exact comparison.
    mid = tracker.percentile(0.5)
    assert 0.45 <= mid <= 0.55
|
||||
|
||||
|
||||
def test_percentile_minimum_with_zero(tracker):
    """q=0.0 yields the smallest recorded sample."""
    for sample in (0.5, 0.2, 0.8):
        tracker.add(sample)

    lowest = tracker.percentile(0.0)
    assert lowest == 0.2
|
||||
|
||||
|
||||
def test_percentile_maximum_with_one(tracker):
    """q=1.0 yields the largest recorded sample."""
    for sample in (0.5, 0.2, 0.8):
        tracker.add(sample)

    highest = tracker.percentile(1.0)
    assert highest == 0.8
|
||||
|
||||
|
||||
def test_percentile_95(tracker):
    """percentile(0.95) over the samples 0.00..0.99 falls close to 0.95."""
    # One hundred samples: 0.00, 0.01, ..., 0.99
    for step in range(100):
        tracker.add(step / 100.0)

    # Accept a narrow band around the expected 95th percentile.
    result = tracker.percentile(0.95)
    assert 0.93 <= result <= 0.96
|
||||
|
||||
|
||||
def test_percentile_negative_value_returns_min(tracker):
    """A negative q is treated like the lower bound and yields the minimum."""
    for sample in (0.5, 0.2, 0.8):
        tracker.add(sample)

    below_range = tracker.percentile(-0.5)
    assert below_range == 0.2
|
||||
|
||||
|
||||
def test_percentile_value_greater_than_one_returns_max(tracker):
    """A q above 1.0 is treated like the upper bound and yields the maximum."""
    for sample in (0.5, 0.2, 0.8):
        tracker.add(sample)

    above_range = tracker.percentile(1.5)
    assert above_range == 0.8
|
||||
|
||||
|
||||
# ====================== p95() Tests ======================
|
||||
|
||||
|
||||
def test_p95_returns_zero_when_empty(tracker):
    """With no samples recorded, p95() reports 0.0."""
    empty_p95 = tracker.p95()
    assert empty_p95 == 0.0
|
||||
|
||||
|
||||
def test_p95_returns_95th_percentile(tracker):
    """p95() over the samples 0.00..0.99 falls close to 0.95."""
    # One hundred evenly spaced samples
    for step in range(100):
        tracker.add(step / 100.0)

    result = tracker.p95()
    assert 0.93 <= result <= 0.96
|
||||
|
||||
|
||||
def test_p95_equals_percentile_95(tracker):
    """p95() gives exactly the same answer as percentile(0.95)."""
    for step in range(50):
        tracker.add(step / 50.0)

    shortcut = tracker.p95()
    explicit = tracker.percentile(0.95)
    assert shortcut == explicit
|
||||
|
||||
|
||||
# ====================== __len__() Tests ======================
|
||||
|
||||
|
||||
def test_len_returns_zero_initially(tracker):
    """A tracker that has received no samples has length 0."""
    initial_len = len(tracker)
    assert initial_len == 0
|
||||
|
||||
|
||||
def test_len_increments_with_add(tracker):
    """len() grows by one with each sample added."""
    assert len(tracker) == 0

    # After the k-th add() the tracker must hold exactly k samples.
    for expected_len, sample in enumerate((0.1, 0.2, 0.3), start=1):
        tracker.add(sample)
        assert len(tracker) == expected_len
|
||||
|
||||
|
||||
def test_len_respects_maxlen(small_tracker):
    """Overfilling the small tracker leaves exactly 5 samples in it."""
    # Ten samples (0.0 .. 0.9), which is more than the window holds.
    for step in range(10):
        small_tracker.add(step / 10.0)

    # Only the last 5 samples are expected to be retained.
    assert len(small_tracker) == 5
|
||||
|
||||
|
||||
def test_len_after_reset(tracker):
    """len() drops back to 0 after reset()."""
    for sample in (0.5, 0.3):
        tracker.add(sample)
    assert len(tracker) == 2

    tracker.reset()
    assert len(tracker) == 0
|
||||
|
||||
|
||||
# ====================== Sliding Window Tests ======================
|
||||
|
||||
|
||||
def test_sliding_window_removes_oldest(small_tracker):
    """Once the window overflows, statistics reflect only the newest samples."""
    # Seven samples into a window of 5: 0.1 and 0.2 should fall out.
    for sample in (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7):
        small_tracker.add(sample)

    # Expected survivors: [0.3, 0.4, 0.5, 0.6, 0.7]
    assert len(small_tracker) == 5

    # Their median is 0.5; a small band around it is accepted.
    windowed_median = small_tracker.percentile(0.5)
    assert 0.45 <= windowed_median <= 0.55
|
||||
|
||||
|
||||
def test_sliding_window_maintains_max(small_tracker):
    """max() still reports 0.9 after a sixth sample overflows the window."""
    # The sixth sample (0.5) overflows the window and pushes out 0.1.
    for sample in (0.1, 0.9, 0.2, 0.3, 0.4, 0.5):
        small_tracker.add(sample)

    # The maximum must be unaffected by the overflow.
    assert small_tracker.max() == 0.9
|
||||
|
||||
|
||||
# ====================== Edge Cases Tests ======================
|
||||
|
||||
|
||||
def test_single_value(tracker):
    """With one sample, max() and the 0/50/100% percentiles all equal it."""
    only = 0.75
    tracker.add(only)

    assert len(tracker) == 1
    assert tracker.max() == only
    for q in (0.0, 0.5, 1.0):
        assert tracker.percentile(q) == only
|
||||
|
||||
|
||||
def test_all_same_values(tracker):
    """Ten identical samples give that value for max() and every percentile."""
    repeated, count = 0.5, 10
    for _ in range(count):
        tracker.add(repeated)

    assert len(tracker) == count
    assert tracker.max() == repeated
    for q in (0.0, 0.5, 1.0):
        assert tracker.percentile(q) == repeated
|
||||
|
||||
|
||||
def test_very_small_values(tracker):
    """Samples on the order of 1e-10 are stored; max() is compared approximately."""
    for sample in (1e-10, 2e-10, 3e-10):
        tracker.add(sample)

    assert len(tracker) == 3
    assert tracker.max() == pytest.approx(3e-10)
|
||||
|
||||
|
||||
def test_very_large_values(tracker):
    """Samples on the order of 1e10 are stored; max() is compared approximately."""
    for sample in (1e10, 2e10, 3e10):
        tracker.add(sample)

    assert len(tracker) == 3
    assert tracker.max() == pytest.approx(3e10)
|
||||
|
||||
|
||||
# ====================== Integration Tests ======================
|
||||
|
||||
|
||||
def test_typical_usage_pattern(tracker):
    """Feed a realistic run of latencies and check the summary statistics agree."""
    # Ten latencies as they might arrive over time
    observed = [0.05, 0.08, 0.12, 0.07, 0.15, 0.09, 0.11, 0.06, 0.14, 0.10]
    for latency in observed:
        tracker.add(latency)

    assert len(tracker) == 10
    assert tracker.max() == 0.15

    # p95 must sit between the median and the maximum (inclusive on both ends).
    p95 = tracker.p95()
    assert tracker.percentile(0.5) <= p95 <= tracker.max()
|
||||
|
||||
|
||||
def test_reset_and_reuse(tracker):
    """A tracker can be reset and then reused for a second batch of samples."""
    # First batch
    for sample in (1.0, 2.0):
        tracker.add(sample)
    assert tracker.max() == 2.0

    tracker.reset()

    # Second batch: statistics must come from these samples only
    for sample in (0.5, 0.8):
        tracker.add(sample)
    assert len(tracker) == 2
    assert tracker.max() == 0.8
    assert tracker.percentile(0.5) <= 0.8
|
||||
|
||||
|
||||
def test_continuous_monitoring(small_tracker):
    """Keep feeding rising latencies and check max grows while the window stays capped."""
    # First five latencies: 0.1 * 1 .. 0.1 * 5
    for multiplier in range(1, 6):
        small_tracker.add(0.1 * multiplier)

    earlier_max = small_tracker.max()

    # Five more, larger latencies (0.1 * 6 .. 0.1 * 10); the window slides.
    for multiplier in range(6, 11):
        small_tracker.add(0.1 * multiplier)

    assert small_tracker.max() > earlier_max  # the maximum went up
    assert len(small_tracker) == 5  # the window size is unchanged
|
||||
|
||||
|
||||
# ====================== Type Conversion Tests ======================
|
||||
|
||||
|
||||
def test_add_with_integer(tracker):
    """An int sample is accepted and compares equal to its float value."""
    int_sample = 5
    tracker.add(int_sample)

    assert len(tracker) == 1
    assert tracker.max() == 5.0
|
||||
|
||||
|
||||
def test_add_with_string_number(tracker):
    """A numeric string sample is accepted and tracked as a number."""
    text_sample = "3.14"
    tracker.add(text_sample)

    assert len(tracker) == 1
    assert tracker.max() == pytest.approx(3.14)
|
||||
|
||||
|
||||
def test_percentile_converts_q_to_float(tracker):
    """percentile() accepts an integer q; q=1 yields the largest sample."""
    for sample in (0.5, 0.8):
        tracker.add(sample)

    # q is deliberately passed as an int rather than a float.
    top = tracker.percentile(1)
    assert top == 0.8
|
||||
@@ -0,0 +1,784 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for RTC modeling module (RTCProcessor)."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
|
||||
# ====================== Fixtures ======================
|
||||
|
||||
|
||||
@pytest.fixture
def rtc_config_debug_enabled():
    """Provide an enabled RTCConfig with debug turned on and debug_maxlen=100."""
    settings = {
        "enabled": True,
        "prefix_attention_schedule": RTCAttentionSchedule.LINEAR,
        "max_guidance_weight": 10.0,
        "execution_horizon": 10,
        "debug": True,
        "debug_maxlen": 100,
    }
    return RTCConfig(**settings)
|
||||
|
||||
|
||||
@pytest.fixture
def rtc_config_debug_disabled():
    """Provide an enabled RTCConfig with debug turned off."""
    settings = {
        "enabled": True,
        "prefix_attention_schedule": RTCAttentionSchedule.LINEAR,
        "max_guidance_weight": 10.0,
        "execution_horizon": 10,
        "debug": False,
    }
    return RTCConfig(**settings)
|
||||
|
||||
|
||||
@pytest.fixture
def rtc_processor_debug_enabled(rtc_config_debug_enabled):
    """Build an RTCProcessor from the debug-enabled config fixture."""
    processor = RTCProcessor(rtc_config_debug_enabled)
    return processor
|
||||
|
||||
|
||||
@pytest.fixture
def rtc_processor_debug_disabled(rtc_config_debug_disabled):
    """Build an RTCProcessor from the debug-disabled config fixture."""
    processor = RTCProcessor(rtc_config_debug_disabled)
    return processor
|
||||
|
||||
|
||||
@pytest.fixture
def sample_x_t():
    """Provide a random x_t tensor of shape (1, 50, 6): (batch, time, action_dim)."""
    batch, time_steps, action_dim = 1, 50, 6
    return torch.randn(batch, time_steps, action_dim)
|
||||
|
||||
|
||||
@pytest.fixture
def sample_prev_chunk():
    """Provide a random previous-chunk tensor of shape (1, 50, 6)."""
    chunk_shape = (1, 50, 6)
    return torch.randn(*chunk_shape)
|
||||
|
||||
|
||||
# ====================== Initialization Tests ======================
|
||||
|
||||
|
||||
def test_rtc_processor_initialization_with_debug(rtc_config_debug_enabled):
    """With debug on, the processor keeps its config and holds an enabled tracker."""
    processor = RTCProcessor(rtc_config_debug_enabled)

    assert processor.rtc_config == rtc_config_debug_enabled

    tracker = processor.tracker
    assert tracker is not None
    assert tracker.enabled is True
|
||||
|
||||
|
||||
def test_rtc_processor_initialization_without_debug(rtc_config_debug_disabled):
    """With debug off, the processor keeps its config and has no tracker."""
    processor = RTCProcessor(rtc_config_debug_disabled)

    assert processor.rtc_config == rtc_config_debug_disabled
    assert processor.tracker is None
|
||||
|
||||
|
||||
# ====================== Tracker Proxy Methods Tests ======================
|
||||
|
||||
|
||||
def test_track_when_tracker_enabled(rtc_processor_debug_enabled, sample_x_t):
    """track() on a debug-enabled processor records one step with the given time."""
    processor = rtc_processor_debug_enabled
    processor.track(time=torch.tensor(0.5), x_t=sample_x_t, v_t=sample_x_t, guidance_weight=2.0)

    # Exactly one step should have been recorded.
    recorded = processor.get_all_debug_steps()
    assert len(recorded) == 1
    assert recorded[0].time == 0.5
|
||||
|
||||
|
||||
def test_track_when_tracker_disabled(rtc_processor_debug_disabled, sample_x_t):
    """track() on a debug-disabled processor is a harmless no-op."""
    processor = rtc_processor_debug_disabled

    # The call itself must not raise.
    processor.track(time=torch.tensor(0.5), x_t=sample_x_t, v_t=sample_x_t)

    # Nothing is recorded, so the list of steps is empty.
    recorded = processor.get_all_debug_steps()
    assert len(recorded) == 0
|
||||
|
||||
|
||||
def test_get_all_debug_steps_when_enabled(rtc_processor_debug_enabled, sample_x_t):
    """get_all_debug_steps() returns one entry per tracked step."""
    processor = rtc_processor_debug_enabled
    for step_time in (0.5, 0.4):
        processor.track(time=torch.tensor(step_time), x_t=sample_x_t)

    recorded = processor.get_all_debug_steps()
    assert len(recorded) == 2
|
||||
|
||||
|
||||
def test_get_all_debug_steps_when_disabled(rtc_processor_debug_disabled):
    """With debug off, get_all_debug_steps() returns an empty list object."""
    recorded = rtc_processor_debug_disabled.get_all_debug_steps()

    assert recorded == []
    # The return value must be a real list, not merely something equal to one.
    assert isinstance(recorded, list)
|
||||
|
||||
|
||||
def test_is_debug_enabled_when_tracker_exists(rtc_processor_debug_enabled):
    """is_debug_enabled() is exactly True on a debug-enabled processor."""
    debug_flag = rtc_processor_debug_enabled.is_debug_enabled()
    assert debug_flag is True
|
||||
|
||||
|
||||
def test_is_debug_enabled_when_tracker_disabled(rtc_processor_debug_disabled):
    """is_debug_enabled() is exactly False on a debug-disabled processor."""
    debug_flag = rtc_processor_debug_disabled.is_debug_enabled()
    assert debug_flag is False
|
||||
|
||||
|
||||
def test_reset_tracker_when_enabled(rtc_processor_debug_enabled, sample_x_t):
    """reset_tracker() discards every step recorded so far."""
    processor = rtc_processor_debug_enabled
    for step_time in (0.5, 0.4):
        processor.track(time=torch.tensor(step_time), x_t=sample_x_t)
    assert len(processor.get_all_debug_steps()) == 2

    processor.reset_tracker()
    assert len(processor.get_all_debug_steps()) == 0
|
||||
|
||||
|
||||
def test_reset_tracker_when_disabled(rtc_processor_debug_disabled):
    """reset_tracker() on a debug-disabled processor must not raise."""
    processor = rtc_processor_debug_disabled
    # Passing means the call returned normally; there is nothing else to check.
    processor.reset_tracker()
|
||||
|
||||
|
||||
# ====================== get_prefix_weights Tests ======================
|
||||
|
||||
|
||||
def test_get_prefix_weights_zeros_schedule():
    """ZEROS schedule: ones up to ``start``, zeros everywhere after."""
    processor = RTCProcessor(RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.ZEROS))

    weights = processor.get_prefix_weights(start=5, end=10, total=20)

    assert weights.shape == (20,)
    # The first 5 entries are 1.0 and the remaining 15 are 0.0.
    head, tail = weights[:5], weights[5:]
    assert torch.all(head == 1.0)
    assert torch.all(tail == 0.0)
|
||||
|
||||
|
||||
def test_get_prefix_weights_ones_schedule():
    """ONES schedule: ones up to ``end``, zeros everywhere after."""
    processor = RTCProcessor(RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.ONES))

    weights = processor.get_prefix_weights(start=5, end=15, total=20)

    assert weights.shape == (20,)
    # The first 15 entries are 1.0 and the remaining 5 are 0.0.
    head, tail = weights[:15], weights[15:]
    assert torch.all(head == 1.0)
    assert torch.all(tail == 0.0)
|
||||
|
||||
|
||||
def test_get_prefix_weights_linear_schedule():
    """LINEAR schedule: leading ones, a falling ramp inside [0, 1], trailing zeros."""
    processor = RTCProcessor(RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.LINEAR))

    weights = processor.get_prefix_weights(start=5, end=15, total=20)

    assert weights.shape == (20,)

    # Leading ones: indices 0..4
    assert torch.all(weights[:5] == 1.0)

    # Ramp: indices 5..14 start higher than they end and stay within [0, 1]
    ramp = weights[5:15]
    assert ramp[0] > ramp[-1]
    assert torch.all(ramp >= 0.0)
    assert torch.all(ramp <= 1.0)

    # Trailing zeros: indices 15..19
    assert torch.all(weights[15:] == 0.0)
|
||||
|
||||
|
||||
def test_get_prefix_weights_exp_schedule():
    """EXP schedule: leading ones, a middle section inside [0, 1], trailing zeros."""
    processor = RTCProcessor(RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.EXP))

    weights = processor.get_prefix_weights(start=5, end=15, total=20)

    assert weights.shape == (20,)

    # Leading ones: indices 0..4
    assert torch.all(weights[:5] == 1.0)

    # Middle section: indices 5..14 only need to stay within [0, 1]
    middle = weights[5:15]
    assert torch.all(middle >= 0.0)
    assert torch.all(middle <= 1.0)

    # Trailing zeros: indices 15..19
    assert torch.all(weights[15:] == 0.0)
|
||||
|
||||
|
||||
def test_get_prefix_weights_with_start_equals_end():
    """With start == end there is no ramp: ones before the boundary, zeros after."""
    processor = RTCProcessor(RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.LINEAR))

    boundary = 10
    weights = processor.get_prefix_weights(start=boundary, end=boundary, total=20)

    assert torch.all(weights[:boundary] == 1.0)
    assert torch.all(weights[boundary:] == 0.0)
|
||||
|
||||
|
||||
def test_get_prefix_weights_with_start_greater_than_end():
    """With start > end the boundary is expected at ``end``: ones before, zeros after."""
    processor = RTCProcessor(RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.LINEAR))

    # start (15) exceeds end (10); the effective start is expected to be min(start, end) = 10.
    weights = processor.get_prefix_weights(start=15, end=10, total=20)

    assert torch.all(weights[:10] == 1.0)
    assert torch.all(weights[10:] == 0.0)
|
||||
|
||||
|
||||
# ====================== Helper Method Tests ======================
|
||||
|
||||
|
||||
def test_linweights_normal_case():
    """_linweights yields a non-increasing ramp strictly between 0 and 1 at its ends."""
    processor = RTCProcessor(RTCConfig())

    ramp = processor._linweights(start=5, end=15, total=20)

    # Expected form: linspace(1, 0, steps + 2)[1:-1], i.e. without the endpoints,
    # with steps = total - (total - end) - start = 20 - 5 - 5 = 10.
    assert len(ramp) == 10
    assert ramp[0] < 1.0  # first value sits below the excluded 1.0
    assert ramp[-1] > 0.0  # last value sits above the excluded 0.0
    assert torch.all(ramp[:-1] >= ramp[1:])  # never increases
|
||||
|
||||
|
||||
def test_linweights_with_end_equals_start():
    """_linweights returns an empty result when end == start."""
    processor = RTCProcessor(RTCConfig())

    ramp = processor._linweights(start=10, end=10, total=20)

    # There is no room for a ramp, so nothing is produced.
    assert len(ramp) == 0
|
||||
|
||||
|
||||
def test_linweights_with_end_less_than_start():
    """_linweights returns an empty result when end < start."""
    processor = RTCProcessor(RTCConfig())

    ramp = processor._linweights(start=15, end=10, total=20)

    # An inverted range leaves no room for a ramp either.
    assert len(ramp) == 0
|
||||
|
||||
|
||||
def test_add_trailing_zeros_normal():
    """_add_trailing_zeros appends ``total - end`` zeros after the given weights."""
    processor = RTCProcessor(RTCConfig())

    prefix = torch.tensor([1.0, 0.8, 0.6, 0.4, 0.2])
    padded = processor._add_trailing_zeros(prefix, total=10, end=5)

    # total - end = 10 - 5 = 5 zeros are appended to the 5 given weights.
    assert len(padded) == 10
    assert torch.all(padded[:5] == prefix)
    assert torch.all(padded[5:] == 0.0)
|
||||
|
||||
|
||||
def test_add_trailing_zeros_no_zeros_needed():
    """_add_trailing_zeros leaves the weights untouched when total <= end."""
    processor = RTCProcessor(RTCConfig())

    prefix = torch.tensor([1.0, 0.8, 0.6])
    padded = processor._add_trailing_zeros(prefix, total=3, end=5)

    # total - end = 3 - 5 = -2, which is not positive, so nothing is appended.
    assert torch.equal(padded, prefix)
|
||||
|
||||
|
||||
def test_add_leading_ones_normal():
    """_add_leading_ones prepends ``start`` ones in front of the given weights."""
    processor = RTCProcessor(RTCConfig())

    suffix = torch.tensor([0.8, 0.6, 0.4, 0.2, 0.0])
    padded = processor._add_leading_ones(suffix, start=3, total=10)

    # 3 ones in front of the 5 given weights give 8 entries.
    assert len(padded) == 8
    assert torch.all(padded[:3] == 1.0)
    assert torch.all(padded[3:] == suffix)
|
||||
|
||||
|
||||
def test_add_leading_ones_no_ones_needed():
    """_add_leading_ones leaves the weights untouched when start is 0."""
    processor = RTCProcessor(RTCConfig())

    suffix = torch.tensor([0.8, 0.6, 0.4])
    padded = processor._add_leading_ones(suffix, start=0, total=10)

    # start = 0 means zero ones to prepend.
    assert torch.equal(padded, suffix)
|
||||
|
||||
|
||||
# ====================== denoise_step Tests ======================
|
||||
|
||||
|
||||
def test_denoise_step_without_prev_chunk(rtc_processor_debug_disabled):
    """Without a previous chunk, denoise_step hands back the denoiser output as is."""
    noisy_actions = torch.randn(1, 50, 6)

    def constant_velocity(x):
        # Stand-in denoiser: a velocity of 0.5 everywhere, shaped like its input.
        return 0.5 * torch.ones_like(x)

    velocity = rtc_processor_debug_disabled.denoise_step(
        x_t=noisy_actions,
        prev_chunk_left_over=None,
        inference_delay=5,
        time=torch.tensor(0.5),
        original_denoise_step_partial=constant_velocity,
    )

    # No guidance is expected, so the result should match the raw denoiser output.
    unguided = constant_velocity(noisy_actions)
    assert torch.allclose(velocity, unguided)
|
||||
|
||||
|
||||
def test_denoise_step_with_prev_chunk(rtc_processor_debug_disabled):
    """With a previous chunk, denoise_step changes the denoiser output but not its shape."""
    noisy_actions = torch.randn(1, 50, 6)
    leftover = torch.randn(1, 50, 6)

    def constant_velocity(x):
        # Stand-in denoiser: a velocity of 0.5 everywhere, shaped like its input.
        return 0.5 * torch.ones_like(x)

    velocity = rtc_processor_debug_disabled.denoise_step(
        x_t=noisy_actions,
        prev_chunk_left_over=leftover,
        inference_delay=5,
        time=torch.tensor(0.5),
        original_denoise_step_partial=constant_velocity,
    )

    # Guidance must have moved the result away from the raw denoiser output.
    unguided = constant_velocity(noisy_actions)
    assert not torch.allclose(velocity, unguided)

    # The shape is unaffected by guidance.
    assert velocity.shape == noisy_actions.shape
|
||||
|
||||
|
||||
def test_denoise_step_adds_batch_dimension():
    """denoise_step accepts unbatched 2D inputs and returns a 2D result."""
    processor = RTCProcessor(RTCConfig(execution_horizon=10, max_guidance_weight=5.0))

    # Both tensors are (time, action_dim) with no batch axis.
    noisy_actions = torch.randn(50, 6)
    leftover = torch.randn(50, 6)

    def constant_velocity(x):
        # Stand-in denoiser: a velocity of 0.5 everywhere, shaped like its input.
        return 0.5 * torch.ones_like(x)

    velocity = processor.denoise_step(
        x_t=noisy_actions,
        prev_chunk_left_over=leftover,
        inference_delay=5,
        time=torch.tensor(0.5),
        original_denoise_step_partial=constant_velocity,
    )

    # The output comes back without a batch axis, matching the input layout.
    assert velocity.ndim == 2
    assert velocity.shape == (50, 6)
|
||||
|
||||
|
||||
def test_denoise_step_pads_shorter_prev_chunk():
    """denoise_step copes with a previous chunk that has fewer timesteps than x_t."""
    processor = RTCProcessor(RTCConfig(execution_horizon=10, max_guidance_weight=5.0))

    noisy_actions = torch.randn(1, 50, 6)
    leftover = torch.randn(1, 30, 6)  # 30 timesteps against 50 in x_t

    def constant_velocity(x):
        # Stand-in denoiser: a velocity of 0.5 everywhere, shaped like its input.
        return 0.5 * torch.ones_like(x)

    velocity = processor.denoise_step(
        x_t=noisy_actions,
        prev_chunk_left_over=leftover,
        inference_delay=5,
        time=torch.tensor(0.5),
        original_denoise_step_partial=constant_velocity,
    )

    # The call must succeed and keep the shape of x_t.
    assert velocity.shape == noisy_actions.shape
|
||||
|
||||
|
||||
def test_denoise_step_pads_fewer_action_dims():
    """denoise_step copes with a previous chunk that has fewer action dims than x_t."""
    processor = RTCProcessor(RTCConfig(execution_horizon=10, max_guidance_weight=5.0))

    noisy_actions = torch.randn(1, 50, 6)
    leftover = torch.randn(1, 50, 4)  # 4 action dims against 6 in x_t

    def constant_velocity(x):
        # Stand-in denoiser: a velocity of 0.5 everywhere, shaped like its input.
        return 0.5 * torch.ones_like(x)

    velocity = processor.denoise_step(
        x_t=noisy_actions,
        prev_chunk_left_over=leftover,
        inference_delay=5,
        time=torch.tensor(0.5),
        original_denoise_step_partial=constant_velocity,
    )

    # The call must succeed and keep the shape of x_t.
    assert velocity.shape == noisy_actions.shape
|
||||
|
||||
|
||||
def test_denoise_step_uses_custom_execution_horizon():
    """denoise_step accepts an execution_horizon argument that differs from the config."""
    processor = RTCProcessor(RTCConfig(execution_horizon=10))

    noisy_actions = torch.randn(1, 50, 6)
    leftover = torch.randn(1, 50, 6)

    def constant_velocity(x):
        # Stand-in denoiser: a velocity of 0.5 everywhere, shaped like its input.
        return 0.5 * torch.ones_like(x)

    # The config says 10; the call passes 20 instead.
    velocity = processor.denoise_step(
        x_t=noisy_actions,
        prev_chunk_left_over=leftover,
        inference_delay=5,
        time=torch.tensor(0.5),
        original_denoise_step_partial=constant_velocity,
        execution_horizon=20,
    )

    assert velocity.shape == noisy_actions.shape
|
||||
|
||||
|
||||
def test_denoise_step_clamps_execution_horizon_to_prev_chunk_length():
    """denoise_step works when execution_horizon exceeds the previous chunk's length."""
    processor = RTCProcessor(RTCConfig(execution_horizon=100))  # far larger than the chunk

    noisy_actions = torch.randn(1, 50, 6)
    leftover = torch.randn(1, 20, 6)  # only 20 timesteps available

    def constant_velocity(x):
        # Stand-in denoiser: a velocity of 0.5 everywhere, shaped like its input.
        return 0.5 * torch.ones_like(x)

    # The horizon of 100 is expected to be clamped to 20 internally.
    velocity = processor.denoise_step(
        x_t=noisy_actions,
        prev_chunk_left_over=leftover,
        inference_delay=5,
        time=torch.tensor(0.5),
        original_denoise_step_partial=constant_velocity,
    )

    assert velocity.shape == noisy_actions.shape
|
||||
|
||||
|
||||
def test_denoise_step_guidance_weight_calculation():
    """denoise_step at time=0.5 yields a finite result with the shape of x_t."""
    processor = RTCProcessor(RTCConfig(max_guidance_weight=10.0))

    noisy_actions = torch.randn(1, 50, 6)
    leftover = torch.randn(1, 50, 6)

    def constant_velocity(x):
        # Stand-in denoiser: a velocity of 0.5 everywhere, shaped like its input.
        return 0.5 * torch.ones_like(x)

    # time = 0.5 gives tau = 1 - 0.5 = 0.5; passed as a plain float, not a tensor.
    midpoint = 0.5
    velocity = processor.denoise_step(
        x_t=noisy_actions,
        prev_chunk_left_over=leftover,
        inference_delay=5,
        time=midpoint,
        original_denoise_step_partial=constant_velocity,
    )

    # The output must be well-formed: right shape, no NaN, no Inf.
    assert velocity.shape == noisy_actions.shape
    assert not torch.any(torch.isnan(velocity))
    assert not torch.any(torch.isinf(velocity))
|
||||
|
||||
|
||||
def test_denoise_step_guidance_weight_at_time_zero():
    """denoise_step at time=0 (tau=1) produces neither NaN nor Inf."""
    processor = RTCProcessor(RTCConfig(max_guidance_weight=10.0))

    noisy_actions = torch.randn(1, 50, 6)
    leftover = torch.randn(1, 50, 6)

    def constant_velocity(x):
        # Stand-in denoiser: a velocity of 0.5 everywhere, shaped like its input.
        return 0.5 * torch.ones_like(x)

    # time = 0 gives tau = 1, so c = (1 - tau) / tau = 0 / 1 = 0.
    velocity = processor.denoise_step(
        x_t=noisy_actions,
        prev_chunk_left_over=leftover,
        inference_delay=5,
        time=torch.tensor(0.0),
        original_denoise_step_partial=constant_velocity,
    )

    # The boundary case must be handled without non-finite values.
    assert not torch.any(torch.isnan(velocity))
    assert not torch.any(torch.isinf(velocity))
|
||||
|
||||
|
||||
def test_denoise_step_guidance_weight_at_time_one():
    """denoise_step at time=1 (tau=0) produces no Inf thanks to max_guidance_weight."""
    processor = RTCProcessor(RTCConfig(max_guidance_weight=10.0))

    noisy_actions = torch.randn(1, 50, 6)
    leftover = torch.randn(1, 50, 6)

    def constant_velocity(x):
        # Stand-in denoiser: a velocity of 0.5 everywhere, shaped like its input.
        return 0.5 * torch.ones_like(x)

    # time = 1 gives tau = 0, so c = (1 - tau) / tau = 1 / 0 = inf, which is
    # expected to be clamped to max_guidance_weight.
    velocity = processor.denoise_step(
        x_t=noisy_actions,
        prev_chunk_left_over=leftover,
        inference_delay=5,
        time=torch.tensor(1.0),
        original_denoise_step_partial=constant_velocity,
    )

    # Clamping must keep infinities out of the result.
    assert not torch.any(torch.isinf(velocity))
|
||||
|
||||
|
||||
def test_denoise_step_tracks_debug_info(rtc_processor_debug_enabled):
    """Verify that one guided denoise_step call records one debug step.

    Uses the ``rtc_processor_debug_enabled`` fixture, runs a single step at
    ``time=0.5`` with ``inference_delay=5``, and then inspects the recorded
    entry returned by ``get_all_debug_steps()``.
    """
    processor = rtc_processor_debug_enabled

    noisy_actions = torch.randn(1, 50, 6)
    leftover_chunk = torch.randn(1, 50, 6)

    def constant_denoiser(x):
        # Stand-in for the real model: 0.5 everywhere, same shape as the input.
        return 0.5 * torch.ones_like(x)

    processor.denoise_step(
        x_t=noisy_actions,
        prev_chunk_left_over=leftover_chunk,
        inference_delay=5,
        time=torch.tensor(0.5),
        original_denoise_step_partial=constant_denoiser,
    )

    # Exactly one step should have been recorded.
    recorded = processor.get_all_debug_steps()
    assert len(recorded) == 1

    # The recorded entry must carry the call's scalars and all intermediates.
    entry = recorded[0]
    assert entry.time == 0.5
    for field_name in ("x1_t", "correction", "err", "weights", "guidance_weight"):
        assert getattr(entry, field_name) is not None
    assert entry.inference_delay == 5
|
||||
|
||||
|
||||
def test_denoise_step_doesnt_track_without_debug(rtc_processor_debug_disabled):
    """Verify that no debug step is recorded when debugging is disabled.

    Uses the ``rtc_processor_debug_disabled`` fixture, runs a single guided
    step at ``time=0.5``, and expects ``get_all_debug_steps()`` to be empty.
    """
    processor = rtc_processor_debug_disabled

    noisy_actions = torch.randn(1, 50, 6)
    leftover_chunk = torch.randn(1, 50, 6)

    def constant_denoiser(x):
        # Stand-in for the real model: 0.5 everywhere, same shape as the input.
        return 0.5 * torch.ones_like(x)

    processor.denoise_step(
        x_t=noisy_actions,
        prev_chunk_left_over=leftover_chunk,
        inference_delay=5,
        time=torch.tensor(0.5),
        original_denoise_step_partial=constant_denoiser,
    )

    # Nothing should have been recorded.
    recorded = processor.get_all_debug_steps()
    assert len(recorded) == 0
|
||||
|
||||
|
||||
# ====================== Integration Tests ======================
|
||||
|
||||
|
||||
def test_denoise_step_full_workflow():
    """Run two consecutive denoise_step calls end to end.

    The first call passes ``prev_chunk_left_over=None``; its output is then
    fed to the second call as the previous chunk. Both outputs must keep the
    ``(1, 50, 6)`` input shape, and with ``debug=True`` exactly one debug step
    is expected to be recorded.
    """
    processor = RTCProcessor(
        RTCConfig(
            enabled=True,
            prefix_attention_schedule=RTCAttentionSchedule.LINEAR,
            max_guidance_weight=5.0,
            execution_horizon=10,
            debug=True,
        )
    )

    # Inputs for the two simulated denoising steps.
    first_input = torch.randn(1, 50, 6)
    second_input = torch.randn(1, 50, 6)

    def noisy_denoiser(x):
        # Small random output with the same shape as the input.
        return 0.1 * torch.randn_like(x)

    # Arguments shared by both calls.
    shared = {
        "inference_delay": 5,
        "original_denoise_step_partial": noisy_denoiser,
    }

    # First step - no guidance
    unguided = processor.denoise_step(
        x_t=first_input,
        prev_chunk_left_over=None,
        time=torch.tensor(0.8),
        **shared,
    )

    # Second step - with guidance, using the first result as the previous chunk
    guided = processor.denoise_step(
        x_t=second_input,
        prev_chunk_left_over=unguided,
        time=torch.tensor(0.6),
        **shared,
    )

    # Both calls should complete and preserve the input shape.
    expected_shape = (1, 50, 6)
    assert unguided.shape == expected_shape
    assert guided.shape == expected_shape

    # Should have tracked one step (second one, first had no prev_chunk)
    assert len(processor.get_all_debug_steps()) == 1
|
||||
|
||||
|
||||
def test_get_prefix_weights_integration():
    """Check get_prefix_weights output structure for every attention schedule.

    For each of ZEROS, ONES, LINEAR and EXP, the weights computed with
    ``start=5, end=15, total=20`` must have shape ``(20,)``, lie in [0, 1],
    and contain no NaN or Inf values.
    """
    all_schedules = (
        RTCAttentionSchedule.ZEROS,
        RTCAttentionSchedule.ONES,
        RTCAttentionSchedule.LINEAR,
        RTCAttentionSchedule.EXP,
    )

    for current_schedule in all_schedules:
        processor = RTCProcessor(RTCConfig(prefix_attention_schedule=current_schedule))
        weights = processor.get_prefix_weights(start=5, end=15, total=20)

        # Correct shape
        assert weights.shape == (20,)

        # Valid range [0, 1]
        assert (weights >= 0.0).all()
        assert (weights <= 1.0).all()

        # No NaN or Inf
        assert not torch.isnan(weights).any()
        assert not torch.isinf(weights).any()
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_denoise_step_with_cuda_tensors():
    """Check that denoise_step accepts CUDA inputs and returns a CUDA result.

    Skipped when CUDA is unavailable. ``x_t`` and the previous chunk are
    created on the CUDA device, while ``time`` is created without an explicit
    device. The result must be on CUDA and keep the shape of ``x_t``.
    """
    processor = RTCProcessor(RTCConfig(execution_horizon=10, max_guidance_weight=5.0))

    gpu = "cuda"
    noisy_actions = torch.randn(1, 50, 6, device=gpu)
    leftover_chunk = torch.randn(1, 50, 6, device=gpu)

    def constant_denoiser(x):
        # Stand-in for the real model: 0.5 everywhere, same shape/device as the input.
        return 0.5 * torch.ones_like(x)

    result = processor.denoise_step(
        x_t=noisy_actions,
        prev_chunk_left_over=leftover_chunk,
        inference_delay=5,
        time=torch.tensor(0.5),
        original_denoise_step_partial=constant_denoiser,
    )

    # Result should be on CUDA and shaped like the input.
    assert result.device.type == "cuda"
    assert result.shape == noisy_actions.shape
|
||||
|
||||
|
||||
def test_denoise_step_deterministic_with_same_inputs():
    """Check that two denoise_step calls with identical inputs agree.

    Inputs are drawn after seeding the global torch RNG with 42. Each call
    receives fresh clones of the same tensors and a constant mock denoiser,
    and the two outputs are compared with ``torch.allclose``.
    """
    processor = RTCProcessor(RTCConfig(execution_horizon=10, max_guidance_weight=5.0))

    torch.manual_seed(42)
    base_x_t = torch.randn(1, 50, 6)
    base_prev_chunk = torch.randn(1, 50, 6)

    def deterministic_denoiser(x):
        # Output depends only on the input's shape, never on randomness.
        return 0.5 * torch.ones_like(x)

    def run_once():
        # Clone on every call so neither run can observe the other's tensors.
        return processor.denoise_step(
            x_t=base_x_t.clone(),
            prev_chunk_left_over=base_prev_chunk.clone(),
            inference_delay=5,
            time=torch.tensor(0.5),
            original_denoise_step_partial=deterministic_denoiser,
        )

    first_run = run_once()
    second_run = run_once()

    # Should produce identical results
    assert torch.allclose(first_run, second_run)
|
||||
Reference in New Issue
Block a user