Fix tests

This commit is contained in:
Eugene Mironov
2025-11-11 00:00:01 +07:00
parent 6db3afca6f
commit 6b6c0623cc
4 changed files with 36 additions and 62 deletions
+1 -1
View File
@@ -18,7 +18,7 @@ Usage:
--rtc.execution_horizon=8 \
--device=mps \
--rtc.max_guidance_weight=10.0 \
--rtc.prefix_attention_schedule=ONES \
--rtc.prefix_attention_schedule=EXP \
--seed=10
# Basic usage with pi0.5 policy
Binary file not shown.

Before

Width:  |  Height:  |  Size: 538 KiB

After

Width:  |  Height:  |  Size: 1.3 MiB

-5
View File
@@ -217,11 +217,6 @@ class RTCProcessor:
grad_outputs = err.clone().detach()
correction = torch.autograd.grad(x1_t, x_t, grad_outputs, retain_graph=False)[0]
# Explicitly nullify correction after execution horizon to ensure exact match with no-RTC
# Create a mask that zeros out correction after execution_horizon
correction_mask = weights.clone() # weights already have zeros after execution_horizon
correction = correction * correction_mask
max_guidance_weight = torch.as_tensor(self.rtc_config.max_guidance_weight)
tau_tensor = torch.as_tensor(tau)
squared_one_minus_tau = (1 - tau_tensor) ** 2
+35 -56
View File
@@ -200,22 +200,20 @@ def test_get_prefix_weights_linear_schedule():
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.LINEAR)
processor = RTCProcessor(config)
weights = processor.get_prefix_weights(start=5, end=15, total=20)
weights = processor.get_prefix_weights(start=5, end=14, total=25)
# Should have shape (20,)
assert weights.shape == (20,)
assert weights.shape == (25,)
# First 5 should be 1.0 (leading ones)
assert torch.all(weights[:5] == 1.0)
# Middle section (5:15) should be linearly decreasing from 1 to 0
middle_weights = weights[5:15]
assert middle_weights[0] > middle_weights[-1] # Decreasing
assert torch.all(middle_weights >= 0.0)
assert torch.all(middle_weights <= 1.0)
middle_weights = torch.tensor([0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1])
assert torch.allclose(weights[5:14], middle_weights)
# Last 5 should be 0.0 (trailing zeros)
assert torch.all(weights[15:] == 0.0)
assert torch.all(weights[14:] == 0.0)
def test_get_prefix_weights_exp_schedule():
@@ -223,21 +221,20 @@ def test_get_prefix_weights_exp_schedule():
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.EXP)
processor = RTCProcessor(config)
weights = processor.get_prefix_weights(start=5, end=15, total=20)
weights = processor.get_prefix_weights(start=5, end=14, total=25)
# Should have shape (20,)
assert weights.shape == (20,)
assert weights.shape == (25,)
# First 5 should be 1.0 (leading ones)
assert torch.all(weights[:5] == 1.0)
# Middle section should be exponentially weighted
middle_weights = weights[5:15]
assert torch.all(middle_weights >= 0.0)
assert torch.all(middle_weights <= 1.0)
middle_weights = torch.tensor([0.7645, 0.5706, 0.4130, 0.2871, 0.1888, 0.1145, 0.0611, 0.0258, 0.0061])
assert torch.allclose(weights[5:14], middle_weights, atol=1e-4)
# Last 5 should be 0.0 (trailing zeros)
assert torch.all(weights[15:] == 0.0)
assert torch.all(weights[14:] == 0.0)
def test_get_prefix_weights_with_start_equals_end():
@@ -268,22 +265,6 @@ def test_get_prefix_weights_with_start_greater_than_end():
# ====================== Helper Method Tests ======================
def test_linweights_normal_case():
"""Test _linweights with normal parameters."""
config = RTCConfig()
processor = RTCProcessor(config)
weights = processor._linweights(start=5, end=15, total=20)
# Should create linear weights from 1 to 0
# Excluding the endpoints: linspace(1, 0, steps+2)[1:-1]
# Steps = total - (total - end) - start = 20 - 5 - 5 = 10
assert len(weights) == 10
assert weights[0] < 1.0 # First value after 1.0
assert weights[-1] > 0.0 # Last value before 0.0
assert torch.all(weights[:-1] >= weights[1:]) # Decreasing
def test_linweights_with_end_equals_start():
"""Test _linweights when end equals start."""
config = RTCConfig()
@@ -358,6 +339,31 @@ def test_add_leading_ones_no_ones_needed():
assert torch.equal(result, weights)
def test_get_prefix_weights_with_start_equals_total():
"""Test get_prefix_weights when start equals total."""
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.LINEAR)
processor = RTCProcessor(config)
weights = processor.get_prefix_weights(start=10, end=10, total=20)
# Should have ones up to start, then zeros
assert len(weights) == 20
assert torch.all(weights[:10] == 1.0)
assert torch.all(weights[10:] == 0.0)
def test_get_prefix_weights_with_total_less_than_start():
"""Test get_prefix_weights when total less than start."""
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.LINEAR)
processor = RTCProcessor(config)
weights = processor.get_prefix_weights(start=10, end=10, total=5)
# Should have ones up to start, then zeros
assert len(weights) == 5
assert torch.all(weights == 1.0)
# ====================== denoise_step Tests ======================
@@ -700,33 +706,6 @@ def test_denoise_step_full_workflow():
assert len(steps) == 1
def test_get_prefix_weights_integration():
"""Test get_prefix_weights produces expected structure for all schedules."""
schedules = [
RTCAttentionSchedule.ZEROS,
RTCAttentionSchedule.ONES,
RTCAttentionSchedule.LINEAR,
RTCAttentionSchedule.EXP,
]
for schedule in schedules:
config = RTCConfig(prefix_attention_schedule=schedule)
processor = RTCProcessor(config)
weights = processor.get_prefix_weights(start=5, end=15, total=20)
# All should have correct shape
assert weights.shape == (20,)
# All should be in valid range [0, 1]
assert torch.all(weights >= 0.0)
assert torch.all(weights <= 1.0)
# All should have no NaN or Inf
assert not torch.any(torch.isnan(weights))
assert not torch.any(torch.isinf(weights))
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_denoise_step_with_cuda_tensors():
"""Test denoise_step works with CUDA tensors."""