mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 09:09:48 +00:00
Fix tests
This commit is contained in:
@@ -18,7 +18,7 @@ Usage:
|
||||
--rtc.execution_horizon=8 \
|
||||
--device=mps \
|
||||
--rtc.max_guidance_weight=10.0 \
|
||||
--rtc.prefix_attention_schedule=ONES \
|
||||
--rtc.prefix_attention_schedule=EXP \
|
||||
--seed=10
|
||||
|
||||
# Basic usage with pi0.5 policy
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 538 KiB After Width: | Height: | Size: 1.3 MiB |
@@ -217,11 +217,6 @@ class RTCProcessor:
|
||||
grad_outputs = err.clone().detach()
|
||||
correction = torch.autograd.grad(x1_t, x_t, grad_outputs, retain_graph=False)[0]
|
||||
|
||||
# Explicitly nullify correction after execution horizon to ensure exact match with no-RTC
|
||||
# Create a mask that zeros out correction after execution_horizon
|
||||
correction_mask = weights.clone() # weights already have zeros after execution_horizon
|
||||
correction = correction * correction_mask
|
||||
|
||||
max_guidance_weight = torch.as_tensor(self.rtc_config.max_guidance_weight)
|
||||
tau_tensor = torch.as_tensor(tau)
|
||||
squared_one_minus_tau = (1 - tau_tensor) ** 2
|
||||
|
||||
@@ -200,22 +200,20 @@ def test_get_prefix_weights_linear_schedule():
|
||||
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.LINEAR)
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
weights = processor.get_prefix_weights(start=5, end=15, total=20)
|
||||
weights = processor.get_prefix_weights(start=5, end=14, total=25)
|
||||
|
||||
# Should have shape (20,)
|
||||
assert weights.shape == (20,)
|
||||
assert weights.shape == (25,)
|
||||
|
||||
# First 5 should be 1.0 (leading ones)
|
||||
assert torch.all(weights[:5] == 1.0)
|
||||
|
||||
# Middle section (5:15) should be linearly decreasing from 1 to 0
|
||||
middle_weights = weights[5:15]
|
||||
assert middle_weights[0] > middle_weights[-1] # Decreasing
|
||||
assert torch.all(middle_weights >= 0.0)
|
||||
assert torch.all(middle_weights <= 1.0)
|
||||
middle_weights = torch.tensor([0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1])
|
||||
assert torch.allclose(weights[5:14], middle_weights)
|
||||
|
||||
# Last 5 should be 0.0 (trailing zeros)
|
||||
assert torch.all(weights[15:] == 0.0)
|
||||
assert torch.all(weights[14:] == 0.0)
|
||||
|
||||
|
||||
def test_get_prefix_weights_exp_schedule():
|
||||
@@ -223,21 +221,20 @@ def test_get_prefix_weights_exp_schedule():
|
||||
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.EXP)
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
weights = processor.get_prefix_weights(start=5, end=15, total=20)
|
||||
weights = processor.get_prefix_weights(start=5, end=14, total=25)
|
||||
|
||||
# Should have shape (20,)
|
||||
assert weights.shape == (20,)
|
||||
assert weights.shape == (25,)
|
||||
|
||||
# First 5 should be 1.0 (leading ones)
|
||||
assert torch.all(weights[:5] == 1.0)
|
||||
|
||||
# Middle section should be exponentially weighted
|
||||
middle_weights = weights[5:15]
|
||||
assert torch.all(middle_weights >= 0.0)
|
||||
assert torch.all(middle_weights <= 1.0)
|
||||
middle_weights = torch.tensor([0.7645, 0.5706, 0.4130, 0.2871, 0.1888, 0.1145, 0.0611, 0.0258, 0.0061])
|
||||
assert torch.allclose(weights[5:14], middle_weights, atol=1e-4)
|
||||
|
||||
# Last 5 should be 0.0 (trailing zeros)
|
||||
assert torch.all(weights[15:] == 0.0)
|
||||
assert torch.all(weights[14:] == 0.0)
|
||||
|
||||
|
||||
def test_get_prefix_weights_with_start_equals_end():
|
||||
@@ -268,22 +265,6 @@ def test_get_prefix_weights_with_start_greater_than_end():
|
||||
# ====================== Helper Method Tests ======================
|
||||
|
||||
|
||||
def test_linweights_normal_case():
|
||||
"""Test _linweights with normal parameters."""
|
||||
config = RTCConfig()
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
weights = processor._linweights(start=5, end=15, total=20)
|
||||
|
||||
# Should create linear weights from 1 to 0
|
||||
# Excluding the endpoints: linspace(1, 0, steps+2)[1:-1]
|
||||
# Steps = total - (total - end) - start = 20 - 5 - 5 = 10
|
||||
assert len(weights) == 10
|
||||
assert weights[0] < 1.0 # First value after 1.0
|
||||
assert weights[-1] > 0.0 # Last value before 0.0
|
||||
assert torch.all(weights[:-1] >= weights[1:]) # Decreasing
|
||||
|
||||
|
||||
def test_linweights_with_end_equals_start():
|
||||
"""Test _linweights when end equals start."""
|
||||
config = RTCConfig()
|
||||
@@ -358,6 +339,31 @@ def test_add_leading_ones_no_ones_needed():
|
||||
assert torch.equal(result, weights)
|
||||
|
||||
|
||||
def test_get_prefix_weights_with_start_equals_total():
|
||||
"""Test get_prefix_weights when start equals total."""
|
||||
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.LINEAR)
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
weights = processor.get_prefix_weights(start=10, end=10, total=20)
|
||||
|
||||
# Should have ones up to start, then zeros
|
||||
assert len(weights) == 20
|
||||
assert torch.all(weights[:10] == 1.0)
|
||||
assert torch.all(weights[10:] == 0.0)
|
||||
|
||||
|
||||
def test_get_prefix_weights_with_total_less_than_start():
|
||||
"""Test get_prefix_weights when total less than start."""
|
||||
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.LINEAR)
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
weights = processor.get_prefix_weights(start=10, end=10, total=5)
|
||||
|
||||
# Should have ones up to start, then zeros
|
||||
assert len(weights) == 5
|
||||
assert torch.all(weights == 1.0)
|
||||
|
||||
|
||||
# ====================== denoise_step Tests ======================
|
||||
|
||||
|
||||
@@ -700,33 +706,6 @@ def test_denoise_step_full_workflow():
|
||||
assert len(steps) == 1
|
||||
|
||||
|
||||
def test_get_prefix_weights_integration():
|
||||
"""Test get_prefix_weights produces expected structure for all schedules."""
|
||||
schedules = [
|
||||
RTCAttentionSchedule.ZEROS,
|
||||
RTCAttentionSchedule.ONES,
|
||||
RTCAttentionSchedule.LINEAR,
|
||||
RTCAttentionSchedule.EXP,
|
||||
]
|
||||
|
||||
for schedule in schedules:
|
||||
config = RTCConfig(prefix_attention_schedule=schedule)
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
weights = processor.get_prefix_weights(start=5, end=15, total=20)
|
||||
|
||||
# All should have correct shape
|
||||
assert weights.shape == (20,)
|
||||
|
||||
# All should be in valid range [0, 1]
|
||||
assert torch.all(weights >= 0.0)
|
||||
assert torch.all(weights <= 1.0)
|
||||
|
||||
# All should have no NaN or Inf
|
||||
assert not torch.any(torch.isnan(weights))
|
||||
assert not torch.any(torch.isinf(weights))
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_denoise_step_with_cuda_tensors():
|
||||
"""Test denoise_step works with CUDA tensors."""
|
||||
|
||||
Reference in New Issue
Block a user