mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 09:09:48 +00:00
Update README
This commit is contained in:
@@ -16,8 +16,6 @@
|
||||
|
||||
"""Tests for RTC configuration module."""
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
|
||||
@@ -65,259 +63,3 @@ def test_rtc_config_partial_initialization():
|
||||
assert config.prefix_attention_schedule == RTCAttentionSchedule.LINEAR
|
||||
assert config.execution_horizon == 10
|
||||
assert config.debug is False
|
||||
|
||||
|
||||
# ====================== Validation Tests ======================
|
||||
|
||||
|
||||
def test_rtc_config_validates_positive_max_guidance_weight():
|
||||
"""Test RTCConfig validates max_guidance_weight is positive."""
|
||||
with pytest.raises(ValueError, match="max_guidance_weight must be positive"):
|
||||
RTCConfig(max_guidance_weight=0.0)
|
||||
|
||||
with pytest.raises(ValueError, match="max_guidance_weight must be positive"):
|
||||
RTCConfig(max_guidance_weight=-1.0)
|
||||
|
||||
|
||||
def test_rtc_config_validates_positive_debug_maxlen():
|
||||
"""Test RTCConfig validates debug_maxlen is positive."""
|
||||
with pytest.raises(ValueError, match="debug_maxlen must be positive"):
|
||||
RTCConfig(debug_maxlen=0)
|
||||
|
||||
with pytest.raises(ValueError, match="debug_maxlen must be positive"):
|
||||
RTCConfig(debug_maxlen=-10)
|
||||
|
||||
|
||||
def test_rtc_config_accepts_valid_max_guidance_weight():
|
||||
"""Test RTCConfig accepts valid positive max_guidance_weight."""
|
||||
config1 = RTCConfig(max_guidance_weight=0.1)
|
||||
assert config1.max_guidance_weight == 0.1
|
||||
|
||||
config2 = RTCConfig(max_guidance_weight=100.0)
|
||||
assert config2.max_guidance_weight == 100.0
|
||||
|
||||
|
||||
def test_rtc_config_accepts_valid_debug_maxlen():
|
||||
"""Test RTCConfig accepts valid positive debug_maxlen."""
|
||||
config1 = RTCConfig(debug_maxlen=1)
|
||||
assert config1.debug_maxlen == 1
|
||||
|
||||
config2 = RTCConfig(debug_maxlen=10000)
|
||||
assert config2.debug_maxlen == 10000
|
||||
|
||||
|
||||
# ====================== Attention Schedule Tests ======================
|
||||
|
||||
|
||||
def test_rtc_config_with_linear_schedule():
|
||||
"""Test RTCConfig with LINEAR attention schedule."""
|
||||
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.LINEAR)
|
||||
assert config.prefix_attention_schedule == RTCAttentionSchedule.LINEAR
|
||||
|
||||
|
||||
def test_rtc_config_with_exp_schedule():
|
||||
"""Test RTCConfig with EXP attention schedule."""
|
||||
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.EXP)
|
||||
assert config.prefix_attention_schedule == RTCAttentionSchedule.EXP
|
||||
|
||||
|
||||
def test_rtc_config_with_zeros_schedule():
|
||||
"""Test RTCConfig with ZEROS attention schedule."""
|
||||
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.ZEROS)
|
||||
assert config.prefix_attention_schedule == RTCAttentionSchedule.ZEROS
|
||||
|
||||
|
||||
def test_rtc_config_with_ones_schedule():
|
||||
"""Test RTCConfig with ONES attention schedule."""
|
||||
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.ONES)
|
||||
assert config.prefix_attention_schedule == RTCAttentionSchedule.ONES
|
||||
|
||||
|
||||
# ====================== Enabled/Disabled Tests ======================
|
||||
|
||||
|
||||
def test_rtc_config_enabled_true():
|
||||
"""Test RTCConfig with enabled=True."""
|
||||
config = RTCConfig(enabled=True)
|
||||
assert config.enabled is True
|
||||
|
||||
|
||||
def test_rtc_config_enabled_false():
|
||||
"""Test RTCConfig with enabled=False."""
|
||||
config = RTCConfig(enabled=False)
|
||||
assert config.enabled is False
|
||||
|
||||
|
||||
# ====================== Debug Tests ======================
|
||||
|
||||
|
||||
def test_rtc_config_debug_enabled():
|
||||
"""Test RTCConfig with debug enabled."""
|
||||
config = RTCConfig(debug=True, debug_maxlen=500)
|
||||
assert config.debug is True
|
||||
assert config.debug_maxlen == 500
|
||||
|
||||
|
||||
def test_rtc_config_debug_disabled():
|
||||
"""Test RTCConfig with debug disabled."""
|
||||
config = RTCConfig(debug=False)
|
||||
assert config.debug is False
|
||||
|
||||
|
||||
# ====================== Execution Horizon Tests ======================
|
||||
|
||||
|
||||
def test_rtc_config_with_small_execution_horizon():
|
||||
"""Test RTCConfig with small execution horizon."""
|
||||
config = RTCConfig(execution_horizon=1)
|
||||
assert config.execution_horizon == 1
|
||||
|
||||
|
||||
def test_rtc_config_with_large_execution_horizon():
|
||||
"""Test RTCConfig with large execution horizon."""
|
||||
config = RTCConfig(execution_horizon=100)
|
||||
assert config.execution_horizon == 100
|
||||
|
||||
|
||||
def test_rtc_config_with_zero_execution_horizon():
|
||||
"""Test RTCConfig accepts zero execution horizon."""
|
||||
# No validation on execution_horizon, so this should work
|
||||
config = RTCConfig(execution_horizon=0)
|
||||
assert config.execution_horizon == 0
|
||||
|
||||
|
||||
def test_rtc_config_with_negative_execution_horizon():
|
||||
"""Test RTCConfig accepts negative execution horizon."""
|
||||
# No validation on execution_horizon, so this should work
|
||||
config = RTCConfig(execution_horizon=-1)
|
||||
assert config.execution_horizon == -1
|
||||
|
||||
|
||||
# ====================== Integration Tests ======================
|
||||
|
||||
|
||||
def test_rtc_config_typical_production_settings():
|
||||
"""Test RTCConfig with typical production settings."""
|
||||
config = RTCConfig(
|
||||
enabled=True,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||
max_guidance_weight=10.0,
|
||||
execution_horizon=8,
|
||||
debug=False,
|
||||
)
|
||||
|
||||
assert config.enabled is True
|
||||
assert config.prefix_attention_schedule == RTCAttentionSchedule.EXP
|
||||
assert config.max_guidance_weight == 10.0
|
||||
assert config.execution_horizon == 8
|
||||
assert config.debug is False
|
||||
|
||||
|
||||
def test_rtc_config_typical_debug_settings():
|
||||
"""Test RTCConfig with typical debug settings."""
|
||||
config = RTCConfig(
|
||||
enabled=True,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.LINEAR,
|
||||
max_guidance_weight=5.0,
|
||||
execution_horizon=10,
|
||||
debug=True,
|
||||
debug_maxlen=1000,
|
||||
)
|
||||
|
||||
assert config.enabled is True
|
||||
assert config.debug is True
|
||||
assert config.debug_maxlen == 1000
|
||||
|
||||
|
||||
def test_rtc_config_disabled_mode():
|
||||
"""Test RTCConfig in disabled mode."""
|
||||
config = RTCConfig(enabled=False)
|
||||
|
||||
assert config.enabled is False
|
||||
# Other settings still accessible even when disabled
|
||||
assert config.max_guidance_weight == 10.0
|
||||
assert config.execution_horizon == 10
|
||||
|
||||
|
||||
# ====================== Dataclass Tests ======================
|
||||
|
||||
|
||||
def test_rtc_config_is_dataclass():
|
||||
"""Test that RTCConfig is a dataclass."""
|
||||
from dataclasses import is_dataclass
|
||||
|
||||
assert is_dataclass(RTCConfig)
|
||||
|
||||
|
||||
def test_rtc_config_equality():
|
||||
"""Test RTCConfig equality comparison."""
|
||||
config1 = RTCConfig(enabled=True, max_guidance_weight=5.0)
|
||||
config2 = RTCConfig(enabled=True, max_guidance_weight=5.0)
|
||||
config3 = RTCConfig(enabled=False, max_guidance_weight=5.0)
|
||||
|
||||
assert config1 == config2
|
||||
assert config1 != config3
|
||||
|
||||
|
||||
def test_rtc_config_repr():
|
||||
"""Test RTCConfig string representation."""
|
||||
config = RTCConfig(enabled=True, execution_horizon=20)
|
||||
repr_str = repr(config)
|
||||
|
||||
assert "RTCConfig" in repr_str
|
||||
assert "enabled=True" in repr_str
|
||||
assert "execution_horizon=20" in repr_str
|
||||
|
||||
|
||||
# ====================== Edge Cases Tests ======================
|
||||
|
||||
|
||||
def test_rtc_config_very_small_max_guidance_weight():
|
||||
"""Test RTCConfig with very small positive max_guidance_weight."""
|
||||
config = RTCConfig(max_guidance_weight=1e-10)
|
||||
assert config.max_guidance_weight == pytest.approx(1e-10)
|
||||
|
||||
|
||||
def test_rtc_config_very_large_max_guidance_weight():
|
||||
"""Test RTCConfig with very large max_guidance_weight."""
|
||||
config = RTCConfig(max_guidance_weight=1e10)
|
||||
assert config.max_guidance_weight == pytest.approx(1e10)
|
||||
|
||||
|
||||
def test_rtc_config_minimum_debug_maxlen():
|
||||
"""Test RTCConfig with minimum valid debug_maxlen."""
|
||||
config = RTCConfig(debug_maxlen=1)
|
||||
assert config.debug_maxlen == 1
|
||||
|
||||
|
||||
def test_rtc_config_float_max_guidance_weight():
|
||||
"""Test RTCConfig with float max_guidance_weight."""
|
||||
config = RTCConfig(max_guidance_weight=3.14159)
|
||||
assert config.max_guidance_weight == pytest.approx(3.14159)
|
||||
|
||||
|
||||
# ====================== Type Tests ======================
|
||||
|
||||
|
||||
def test_rtc_config_enabled_type():
|
||||
"""Test RTCConfig enabled field accepts boolean."""
|
||||
config = RTCConfig(enabled=True)
|
||||
assert isinstance(config.enabled, bool)
|
||||
|
||||
|
||||
def test_rtc_config_execution_horizon_type():
|
||||
"""Test RTCConfig execution_horizon field accepts integer."""
|
||||
config = RTCConfig(execution_horizon=15)
|
||||
assert isinstance(config.execution_horizon, int)
|
||||
|
||||
|
||||
def test_rtc_config_max_guidance_weight_type():
|
||||
"""Test RTCConfig max_guidance_weight field accepts float."""
|
||||
config = RTCConfig(max_guidance_weight=7.5)
|
||||
assert isinstance(config.max_guidance_weight, float)
|
||||
|
||||
|
||||
def test_rtc_config_debug_maxlen_type():
|
||||
"""Test RTCConfig debug_maxlen field accepts integer."""
|
||||
config = RTCConfig(debug_maxlen=200)
|
||||
assert isinstance(config.debug_maxlen, int)
|
||||
|
||||
@@ -1,427 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for RTC debug visualizer module."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.policies.rtc.debug_visualizer import RTCDebugVisualizer
|
||||
|
||||
# ====================== Fixtures ======================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_axes():
|
||||
"""Create mock matplotlib axes."""
|
||||
axes = []
|
||||
for _ in range(6):
|
||||
ax = MagicMock()
|
||||
ax.xaxis.get_label.return_value.get_text.return_value = ""
|
||||
ax.yaxis.get_label.return_value.get_text.return_value = ""
|
||||
axes.append(ax)
|
||||
return axes
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_tensor_2d():
|
||||
"""Create a 2D sample tensor (time_steps, num_dims)."""
|
||||
return torch.randn(50, 6)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_tensor_3d():
|
||||
"""Create a 3D sample tensor (batch, time_steps, num_dims)."""
|
||||
return torch.randn(1, 50, 6)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_numpy_2d():
|
||||
"""Create a 2D numpy array."""
|
||||
return np.random.randn(50, 6)
|
||||
|
||||
|
||||
# ====================== Basic Plotting Tests ======================
|
||||
|
||||
|
||||
def test_plot_waypoints_with_2d_tensor(mock_axes, sample_tensor_2d):
|
||||
"""Test plot_waypoints with 2D tensor."""
|
||||
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d)
|
||||
|
||||
# Should call plot on each axis (6 dimensions)
|
||||
for ax in mock_axes:
|
||||
ax.plot.assert_called_once()
|
||||
|
||||
|
||||
def test_plot_waypoints_with_3d_tensor(mock_axes, sample_tensor_3d):
|
||||
"""Test plot_waypoints with 3D tensor (batch dimension)."""
|
||||
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_3d)
|
||||
|
||||
# Should still plot 6 dimensions (batch dimension removed)
|
||||
for ax in mock_axes:
|
||||
ax.plot.assert_called_once()
|
||||
|
||||
|
||||
def test_plot_waypoints_with_numpy_array(mock_axes, sample_numpy_2d):
|
||||
"""Test plot_waypoints with numpy array."""
|
||||
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_numpy_2d)
|
||||
|
||||
# Should work with numpy arrays
|
||||
for ax in mock_axes:
|
||||
ax.plot.assert_called_once()
|
||||
|
||||
|
||||
def test_plot_waypoints_with_none_tensor(mock_axes):
|
||||
"""Test plot_waypoints returns early when tensor is None."""
|
||||
RTCDebugVisualizer.plot_waypoints(mock_axes, None)
|
||||
|
||||
# Should not call plot on any axis
|
||||
for ax in mock_axes:
|
||||
ax.plot.assert_not_called()
|
||||
|
||||
|
||||
# ====================== Parameter Tests ======================
|
||||
|
||||
|
||||
def test_plot_waypoints_with_custom_color(mock_axes, sample_tensor_2d):
|
||||
"""Test plot_waypoints uses custom color."""
|
||||
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, color="red")
|
||||
|
||||
# Check that color was passed to plot
|
||||
for ax in mock_axes:
|
||||
call_kwargs = ax.plot.call_args[1]
|
||||
assert call_kwargs["color"] == "red"
|
||||
|
||||
|
||||
def test_plot_waypoints_with_custom_label(mock_axes, sample_tensor_2d):
|
||||
"""Test plot_waypoints uses custom label."""
|
||||
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, label="test_label")
|
||||
|
||||
# First axis should have label, others should not
|
||||
first_ax_kwargs = mock_axes[0].plot.call_args[1]
|
||||
assert first_ax_kwargs["label"] == "test_label"
|
||||
|
||||
# Other axes should have empty label
|
||||
for ax in mock_axes[1:]:
|
||||
call_kwargs = ax.plot.call_args[1]
|
||||
assert call_kwargs["label"] == ""
|
||||
|
||||
|
||||
def test_plot_waypoints_with_custom_alpha(mock_axes, sample_tensor_2d):
|
||||
"""Test plot_waypoints uses custom alpha."""
|
||||
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, alpha=0.5)
|
||||
|
||||
for ax in mock_axes:
|
||||
call_kwargs = ax.plot.call_args[1]
|
||||
assert call_kwargs["alpha"] == 0.5
|
||||
|
||||
|
||||
def test_plot_waypoints_with_custom_linewidth(mock_axes, sample_tensor_2d):
|
||||
"""Test plot_waypoints uses custom linewidth."""
|
||||
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, linewidth=3)
|
||||
|
||||
for ax in mock_axes:
|
||||
call_kwargs = ax.plot.call_args[1]
|
||||
assert call_kwargs["linewidth"] == 3
|
||||
|
||||
|
||||
def test_plot_waypoints_with_marker(mock_axes, sample_tensor_2d):
|
||||
"""Test plot_waypoints with marker style."""
|
||||
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, marker="o", markersize=5)
|
||||
|
||||
for ax in mock_axes:
|
||||
call_kwargs = ax.plot.call_args[1]
|
||||
assert call_kwargs["marker"] == "o"
|
||||
assert call_kwargs["markersize"] == 5
|
||||
|
||||
|
||||
def test_plot_waypoints_without_marker(mock_axes, sample_tensor_2d):
|
||||
"""Test plot_waypoints without marker (default)."""
|
||||
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, marker=None)
|
||||
|
||||
# Marker should not be in kwargs when None
|
||||
for ax in mock_axes:
|
||||
call_kwargs = ax.plot.call_args[1]
|
||||
assert "marker" not in call_kwargs
|
||||
assert "markersize" not in call_kwargs
|
||||
|
||||
|
||||
# ====================== start_from Parameter Tests ======================
|
||||
|
||||
|
||||
def test_plot_waypoints_with_start_from_zero(mock_axes, sample_tensor_2d):
|
||||
"""Test plot_waypoints with start_from=0."""
|
||||
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, start_from=0)
|
||||
|
||||
# X indices should start from 0
|
||||
for ax in mock_axes:
|
||||
call_args = ax.plot.call_args[0]
|
||||
x_indices = call_args[0]
|
||||
assert x_indices[0] == 0
|
||||
assert len(x_indices) == 50
|
||||
|
||||
|
||||
def test_plot_waypoints_with_start_from_nonzero(mock_axes, sample_tensor_2d):
|
||||
"""Test plot_waypoints with start_from > 0."""
|
||||
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, start_from=10)
|
||||
|
||||
# X indices should start from 10
|
||||
for ax in mock_axes:
|
||||
call_args = ax.plot.call_args[0]
|
||||
x_indices = call_args[0]
|
||||
assert x_indices[0] == 10
|
||||
assert x_indices[-1] == 59 # 10 + 50 - 1
|
||||
|
||||
|
||||
# ====================== Tensor Shape Tests ======================
|
||||
|
||||
|
||||
def test_plot_waypoints_with_1d_tensor(mock_axes):
|
||||
"""Test plot_waypoints with 1D tensor."""
|
||||
tensor_1d = torch.randn(50)
|
||||
RTCDebugVisualizer.plot_waypoints(mock_axes, tensor_1d)
|
||||
|
||||
# Should reshape to (50, 1) and plot on first axis only
|
||||
mock_axes[0].plot.assert_called_once()
|
||||
for ax in mock_axes[1:]:
|
||||
ax.plot.assert_not_called()
|
||||
|
||||
|
||||
def test_plot_waypoints_with_fewer_dims_than_axes(sample_tensor_2d):
|
||||
"""Test plot_waypoints when tensor has fewer dims than axes."""
|
||||
# Create tensor with only 3 dimensions
|
||||
tensor_3d = sample_tensor_2d[:, :3]
|
||||
|
||||
# Create 6 axes but tensor only has 3 dims
|
||||
mock_axes = [MagicMock() for _ in range(6)]
|
||||
for ax in mock_axes:
|
||||
ax.xaxis.get_label.return_value.get_text.return_value = ""
|
||||
ax.yaxis.get_label.return_value.get_text.return_value = ""
|
||||
|
||||
RTCDebugVisualizer.plot_waypoints(mock_axes, tensor_3d)
|
||||
|
||||
# Should only plot on first 3 axes
|
||||
for i in range(3):
|
||||
mock_axes[i].plot.assert_called_once()
|
||||
for i in range(3, 6):
|
||||
mock_axes[i].plot.assert_not_called()
|
||||
|
||||
|
||||
# ====================== Axis Labeling Tests ======================
|
||||
|
||||
|
||||
def test_plot_waypoints_sets_xlabel(mock_axes, sample_tensor_2d):
|
||||
"""Test plot_waypoints sets x-axis label."""
|
||||
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d)
|
||||
|
||||
for ax in mock_axes:
|
||||
ax.set_xlabel.assert_called_once_with("Step", fontsize=10)
|
||||
|
||||
|
||||
def test_plot_waypoints_sets_ylabel(mock_axes, sample_tensor_2d):
|
||||
"""Test plot_waypoints sets y-axis label."""
|
||||
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d)
|
||||
|
||||
for i, ax in enumerate(mock_axes):
|
||||
ax.set_ylabel.assert_called_once_with(f"Dim {i}", fontsize=10)
|
||||
|
||||
|
||||
def test_plot_waypoints_skips_label_if_exists(sample_tensor_2d):
|
||||
"""Test plot_waypoints doesn't set labels if they already exist."""
|
||||
mock_axes_with_labels = []
|
||||
for _ in range(6):
|
||||
ax = MagicMock()
|
||||
# Simulate existing labels
|
||||
ax.xaxis.get_label.return_value.get_text.return_value = "Existing X Label"
|
||||
ax.yaxis.get_label.return_value.get_text.return_value = "Existing Y Label"
|
||||
mock_axes_with_labels.append(ax)
|
||||
|
||||
RTCDebugVisualizer.plot_waypoints(mock_axes_with_labels, sample_tensor_2d)
|
||||
|
||||
# Should not set labels when they already exist
|
||||
for ax in mock_axes_with_labels:
|
||||
ax.set_xlabel.assert_not_called()
|
||||
ax.set_ylabel.assert_not_called()
|
||||
|
||||
|
||||
# ====================== Grid Tests ======================
|
||||
|
||||
|
||||
def test_plot_waypoints_enables_grid(mock_axes, sample_tensor_2d):
|
||||
"""Test plot_waypoints enables grid on all axes."""
|
||||
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d)
|
||||
|
||||
for ax in mock_axes:
|
||||
ax.grid.assert_called_once_with(True, alpha=0.3)
|
||||
|
||||
|
||||
# ====================== Legend Tests ======================
|
||||
|
||||
|
||||
def test_plot_waypoints_adds_legend_with_label(mock_axes, sample_tensor_2d):
|
||||
"""Test plot_waypoints adds legend when label is provided."""
|
||||
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, label="test_label")
|
||||
|
||||
# Should add legend to first axis only
|
||||
mock_axes[0].legend.assert_called_once_with(loc="best", fontsize=8)
|
||||
|
||||
# Should not add legend to other axes
|
||||
for ax in mock_axes[1:]:
|
||||
ax.legend.assert_not_called()
|
||||
|
||||
|
||||
def test_plot_waypoints_no_legend_without_label(mock_axes, sample_tensor_2d):
|
||||
"""Test plot_waypoints doesn't add legend when no label provided."""
|
||||
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, label="")
|
||||
|
||||
# Should not add legend to any axis
|
||||
for ax in mock_axes:
|
||||
ax.legend.assert_not_called()
|
||||
|
||||
|
||||
# ====================== Data Correctness Tests ======================
|
||||
|
||||
|
||||
def test_plot_waypoints_plots_correct_data(mock_axes, sample_tensor_2d):
|
||||
"""Test plot_waypoints plots correct tensor values."""
|
||||
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, start_from=0)
|
||||
|
||||
# Check first axis to verify data correctness
|
||||
call_args = mock_axes[0].plot.call_args[0]
|
||||
x_indices = call_args[0]
|
||||
y_values = call_args[1]
|
||||
|
||||
# X indices should be 0 to 49
|
||||
np.testing.assert_array_equal(x_indices, np.arange(50))
|
||||
|
||||
# Y values should match first dimension of tensor
|
||||
expected_y = sample_tensor_2d[:, 0].cpu().numpy()
|
||||
np.testing.assert_array_almost_equal(y_values, expected_y)
|
||||
|
||||
|
||||
def test_plot_waypoints_handles_gpu_tensor(mock_axes):
|
||||
"""Test plot_waypoints handles GPU tensors (if CUDA available)."""
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("CUDA not available")
|
||||
|
||||
tensor_gpu = torch.randn(50, 6, device="cuda")
|
||||
RTCDebugVisualizer.plot_waypoints(mock_axes, tensor_gpu)
|
||||
|
||||
# Should successfully plot without errors
|
||||
for ax in mock_axes:
|
||||
ax.plot.assert_called_once()
|
||||
|
||||
|
||||
# ====================== Edge Cases Tests ======================
|
||||
|
||||
|
||||
def test_plot_waypoints_with_empty_tensor(mock_axes):
|
||||
"""Test plot_waypoints with empty tensor."""
|
||||
empty_tensor = torch.empty(0, 6)
|
||||
RTCDebugVisualizer.plot_waypoints(mock_axes, empty_tensor)
|
||||
|
||||
# Should plot empty data
|
||||
for ax in mock_axes:
|
||||
call_args = ax.plot.call_args[0]
|
||||
x_indices = call_args[0]
|
||||
assert len(x_indices) == 0
|
||||
|
||||
|
||||
def test_plot_waypoints_with_single_timestep(mock_axes):
|
||||
"""Test plot_waypoints with single timestep."""
|
||||
single_step_tensor = torch.randn(1, 6)
|
||||
RTCDebugVisualizer.plot_waypoints(mock_axes, single_step_tensor)
|
||||
|
||||
# Should plot single point
|
||||
for ax in mock_axes:
|
||||
call_args = ax.plot.call_args[0]
|
||||
x_indices = call_args[0]
|
||||
assert len(x_indices) == 1
|
||||
|
||||
|
||||
def test_plot_waypoints_with_very_large_tensor(mock_axes):
|
||||
"""Test plot_waypoints with very large tensor."""
|
||||
large_tensor = torch.randn(10000, 6)
|
||||
RTCDebugVisualizer.plot_waypoints(mock_axes, large_tensor)
|
||||
|
||||
# Should handle large tensors
|
||||
for ax in mock_axes:
|
||||
call_args = ax.plot.call_args[0]
|
||||
x_indices = call_args[0]
|
||||
assert len(x_indices) == 10000
|
||||
|
||||
|
||||
# ====================== Multiple Calls Tests ======================
|
||||
|
||||
|
||||
def test_plot_waypoints_multiple_calls_on_same_axes(mock_axes, sample_tensor_2d):
|
||||
"""Test multiple plot_waypoints calls on same axes."""
|
||||
tensor1 = sample_tensor_2d
|
||||
tensor2 = torch.randn(50, 6)
|
||||
|
||||
RTCDebugVisualizer.plot_waypoints(mock_axes, tensor1, color="blue", label="Series 1")
|
||||
RTCDebugVisualizer.plot_waypoints(mock_axes, tensor2, color="red", label="Series 2")
|
||||
|
||||
# Each axis should have been called twice
|
||||
for ax in mock_axes:
|
||||
assert ax.plot.call_count == 2
|
||||
|
||||
|
||||
# ====================== Integration Tests ======================
|
||||
|
||||
|
||||
def test_plot_waypoints_typical_usage(mock_axes, sample_tensor_2d):
|
||||
"""Test plot_waypoints with typical usage pattern."""
|
||||
RTCDebugVisualizer.plot_waypoints(
|
||||
mock_axes, sample_tensor_2d, start_from=0, color="blue", label="Trajectory", alpha=0.7, linewidth=2
|
||||
)
|
||||
|
||||
# Verify all expected calls were made
|
||||
for ax in mock_axes:
|
||||
ax.plot.assert_called_once()
|
||||
ax.set_xlabel.assert_called_once()
|
||||
ax.set_ylabel.assert_called_once()
|
||||
ax.grid.assert_called_once()
|
||||
|
||||
# First axis should have legend
|
||||
mock_axes[0].legend.assert_called_once()
|
||||
|
||||
|
||||
def test_plot_waypoints_with_all_parameters(mock_axes, sample_tensor_2d):
|
||||
"""Test plot_waypoints with all parameters specified."""
|
||||
RTCDebugVisualizer.plot_waypoints(
|
||||
axes=mock_axes,
|
||||
tensor=sample_tensor_2d,
|
||||
start_from=10,
|
||||
color="green",
|
||||
label="Full Test",
|
||||
alpha=0.8,
|
||||
linewidth=3,
|
||||
marker="o",
|
||||
markersize=6,
|
||||
)
|
||||
|
||||
# Check first axis for all parameters
|
||||
call_kwargs = mock_axes[0].plot.call_args[1]
|
||||
assert call_kwargs["color"] == "green"
|
||||
assert call_kwargs["label"] == "Full Test"
|
||||
assert call_kwargs["alpha"] == 0.8
|
||||
assert call_kwargs["linewidth"] == 3
|
||||
assert call_kwargs["marker"] == "o"
|
||||
assert call_kwargs["markersize"] == 6
|
||||
@@ -184,74 +184,6 @@ def test_max_after_reset(tracker):
|
||||
assert tracker.max() == 0.0
|
||||
|
||||
|
||||
# ====================== percentile() Tests ======================
|
||||
|
||||
|
||||
def test_percentile_returns_zero_when_empty(tracker):
|
||||
"""Test percentile() returns 0.0 when tracker is empty."""
|
||||
assert tracker.percentile(0.5) == 0.0
|
||||
assert tracker.percentile(0.95) == 0.0
|
||||
|
||||
|
||||
def test_percentile_median(tracker):
|
||||
"""Test percentile(0.5) returns median."""
|
||||
# Add sorted values for easier verification
|
||||
values = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
|
||||
for v in values:
|
||||
tracker.add(v)
|
||||
|
||||
# Median should be around 0.5
|
||||
median = tracker.percentile(0.5)
|
||||
assert 0.45 <= median <= 0.55
|
||||
|
||||
|
||||
def test_percentile_minimum_with_zero(tracker):
|
||||
"""Test percentile(0.0) returns minimum."""
|
||||
tracker.add(0.5)
|
||||
tracker.add(0.2)
|
||||
tracker.add(0.8)
|
||||
|
||||
assert tracker.percentile(0.0) == 0.2
|
||||
|
||||
|
||||
def test_percentile_maximum_with_one(tracker):
|
||||
"""Test percentile(1.0) returns maximum."""
|
||||
tracker.add(0.5)
|
||||
tracker.add(0.2)
|
||||
tracker.add(0.8)
|
||||
|
||||
assert tracker.percentile(1.0) == 0.8
|
||||
|
||||
|
||||
def test_percentile_95(tracker):
|
||||
"""Test percentile(0.95) returns 95th percentile."""
|
||||
# Add 100 values from 0.0 to 0.99
|
||||
for i in range(100):
|
||||
tracker.add(i / 100.0)
|
||||
|
||||
p95 = tracker.percentile(0.95)
|
||||
# 95th percentile should be around 0.95
|
||||
assert 0.93 <= p95 <= 0.96
|
||||
|
||||
|
||||
def test_percentile_negative_value_returns_min(tracker):
|
||||
"""Test percentile with negative q returns minimum."""
|
||||
tracker.add(0.5)
|
||||
tracker.add(0.2)
|
||||
tracker.add(0.8)
|
||||
|
||||
assert tracker.percentile(-0.5) == 0.2
|
||||
|
||||
|
||||
def test_percentile_value_greater_than_one_returns_max(tracker):
|
||||
"""Test percentile with q > 1.0 returns maximum."""
|
||||
tracker.add(0.5)
|
||||
tracker.add(0.2)
|
||||
tracker.add(0.8)
|
||||
|
||||
assert tracker.percentile(1.5) == 0.8
|
||||
|
||||
|
||||
# ====================== p95() Tests ======================
|
||||
|
||||
|
||||
@@ -278,79 +210,6 @@ def test_p95_equals_percentile_95(tracker):
|
||||
assert tracker.p95() == tracker.percentile(0.95)
|
||||
|
||||
|
||||
# ====================== __len__() Tests ======================
|
||||
|
||||
|
||||
def test_len_returns_zero_initially(tracker):
|
||||
"""Test __len__ returns 0 for new tracker."""
|
||||
assert len(tracker) == 0
|
||||
|
||||
|
||||
def test_len_increments_with_add(tracker):
|
||||
"""Test __len__ increments as values are added."""
|
||||
assert len(tracker) == 0
|
||||
|
||||
tracker.add(0.1)
|
||||
assert len(tracker) == 1
|
||||
|
||||
tracker.add(0.2)
|
||||
assert len(tracker) == 2
|
||||
|
||||
tracker.add(0.3)
|
||||
assert len(tracker) == 3
|
||||
|
||||
|
||||
def test_len_respects_maxlen(small_tracker):
|
||||
"""Test __len__ respects maxlen limit."""
|
||||
# Add more than maxlen values
|
||||
for i in range(10):
|
||||
small_tracker.add(i / 10.0)
|
||||
|
||||
# Should only keep last 5
|
||||
assert len(small_tracker) == 5
|
||||
|
||||
|
||||
def test_len_after_reset(tracker):
|
||||
"""Test __len__ returns 0 after reset."""
|
||||
tracker.add(0.5)
|
||||
tracker.add(0.3)
|
||||
assert len(tracker) == 2
|
||||
|
||||
tracker.reset()
|
||||
assert len(tracker) == 0
|
||||
|
||||
|
||||
# ====================== Sliding Window Tests ======================
|
||||
|
||||
|
||||
def test_sliding_window_removes_oldest(small_tracker):
|
||||
"""Test sliding window removes oldest values."""
|
||||
# Add 7 values to tracker with maxlen=5
|
||||
values = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
|
||||
for v in values:
|
||||
small_tracker.add(v)
|
||||
|
||||
# Should only have last 5: [0.3, 0.4, 0.5, 0.6, 0.7]
|
||||
assert len(small_tracker) == 5
|
||||
|
||||
# Median should reflect last 5 values
|
||||
median = small_tracker.percentile(0.5)
|
||||
assert 0.45 <= median <= 0.55
|
||||
|
||||
|
||||
def test_sliding_window_maintains_max(small_tracker):
|
||||
"""Test sliding window maintains correct max even after overflow."""
|
||||
small_tracker.add(0.1)
|
||||
small_tracker.add(0.9)
|
||||
small_tracker.add(0.2)
|
||||
small_tracker.add(0.3)
|
||||
small_tracker.add(0.4)
|
||||
small_tracker.add(0.5) # Pushes out 0.1
|
||||
|
||||
# Max should still be 0.9
|
||||
assert small_tracker.max() == 0.9
|
||||
|
||||
|
||||
# ====================== Edge Cases Tests ======================
|
||||
|
||||
|
||||
@@ -436,24 +295,6 @@ def test_reset_and_reuse(tracker):
|
||||
assert tracker.percentile(0.5) <= 0.8
|
||||
|
||||
|
||||
def test_continuous_monitoring(small_tracker):
|
||||
"""Test continuous monitoring with sliding window."""
|
||||
# Simulate continuous latency monitoring
|
||||
# First 5 latencies
|
||||
for i in range(5):
|
||||
small_tracker.add(0.1 * (i + 1))
|
||||
|
||||
max_before = small_tracker.max()
|
||||
|
||||
# Add 5 more (window slides)
|
||||
for i in range(5, 10):
|
||||
small_tracker.add(0.1 * (i + 1))
|
||||
|
||||
# Max should have increased
|
||||
assert small_tracker.max() > max_before
|
||||
assert len(small_tracker) == 5 # Window size maintained
|
||||
|
||||
|
||||
# ====================== Type Conversion Tests ======================
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user