diff --git a/examples/rtc/eval_dataset.py b/examples/rtc/eval_dataset.py index ee801a575..c8e1fecb4 100644 --- a/examples/rtc/eval_dataset.py +++ b/examples/rtc/eval_dataset.py @@ -18,7 +18,7 @@ Usage: --rtc.execution_horizon=8 \ --device=mps \ --rtc.max_guidance_weight=10.0 \ - --rtc.prefix_attention_schedule=ONES \ + --rtc.prefix_attention_schedule=EXP \ --seed=10 # Basic usage with pi0.5 policy diff --git a/src/lerobot/policies/rtc/flow_matching.png b/src/lerobot/policies/rtc/flow_matching.png index af7c7bf50..173ae7001 100644 Binary files a/src/lerobot/policies/rtc/flow_matching.png and b/src/lerobot/policies/rtc/flow_matching.png differ diff --git a/src/lerobot/policies/rtc/modeling_rtc.py b/src/lerobot/policies/rtc/modeling_rtc.py index 6a02aa3e8..280905adf 100644 --- a/src/lerobot/policies/rtc/modeling_rtc.py +++ b/src/lerobot/policies/rtc/modeling_rtc.py @@ -217,11 +217,6 @@ class RTCProcessor: grad_outputs = err.clone().detach() correction = torch.autograd.grad(x1_t, x_t, grad_outputs, retain_graph=False)[0] - # Explicitly nullify correction after execution horizon to ensure exact match with no-RTC - # Create a mask that zeros out correction after execution_horizon - correction_mask = weights.clone() # weights already have zeros after execution_horizon - correction = correction * correction_mask - max_guidance_weight = torch.as_tensor(self.rtc_config.max_guidance_weight) tau_tensor = torch.as_tensor(tau) squared_one_minus_tau = (1 - tau_tensor) ** 2 diff --git a/tests/policies/rtc/test_modeling_rtc.py b/tests/policies/rtc/test_modeling_rtc.py index 52940c6d3..f42822bb6 100644 --- a/tests/policies/rtc/test_modeling_rtc.py +++ b/tests/policies/rtc/test_modeling_rtc.py @@ -200,22 +200,20 @@ def test_get_prefix_weights_linear_schedule(): config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.LINEAR) processor = RTCProcessor(config) - weights = processor.get_prefix_weights(start=5, end=15, total=20) + weights = processor.get_prefix_weights(start=5, end=14, total=25) # Should have shape (20,) - assert weights.shape == (20,) + assert weights.shape == (25,) # First 5 should be 1.0 (leading ones) assert torch.all(weights[:5] == 1.0) # Middle section (5:15) should be linearly decreasing from 1 to 0 - middle_weights = weights[5:15] - assert middle_weights[0] > middle_weights[-1] # Decreasing - assert torch.all(middle_weights >= 0.0) - assert torch.all(middle_weights <= 1.0) + middle_weights = torch.tensor([0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]) + assert torch.allclose(weights[5:14], middle_weights) # Last 5 should be 0.0 (trailing zeros) - assert torch.all(weights[15:] == 0.0) + assert torch.all(weights[14:] == 0.0) def test_get_prefix_weights_exp_schedule(): @@ -223,21 +221,20 @@ def test_get_prefix_weights_exp_schedule(): config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.EXP) processor = RTCProcessor(config) - weights = processor.get_prefix_weights(start=5, end=15, total=20) + weights = processor.get_prefix_weights(start=5, end=14, total=25) # Should have shape (20,) - assert weights.shape == (20,) + assert weights.shape == (25,) # First 5 should be 1.0 (leading ones) assert torch.all(weights[:5] == 1.0) # Middle section should be exponentially weighted - middle_weights = weights[5:15] - assert torch.all(middle_weights >= 0.0) - assert torch.all(middle_weights <= 1.0) + middle_weights = torch.tensor([0.7645, 0.5706, 0.4130, 0.2871, 0.1888, 0.1145, 0.0611, 0.0258, 0.0061]) + assert torch.allclose(weights[5:14], middle_weights, atol=1e-4) # Last 5 should be 0.0 (trailing zeros) - assert torch.all(weights[15:] == 0.0) + assert torch.all(weights[14:] == 0.0) def test_get_prefix_weights_with_start_equals_end(): @@ -268,22 +265,6 @@ def test_get_prefix_weights_with_start_greater_than_end(): # ====================== Helper Method Tests ====================== -def test_linweights_normal_case(): - """Test _linweights with normal parameters.""" - config = RTCConfig() - processor = RTCProcessor(config) - - weights = processor._linweights(start=5, end=15, total=20) - - # Should create linear weights from 1 to 0 - # Excluding the endpoints: linspace(1, 0, steps+2)[1:-1] - # Steps = total - (total - end) - start = 20 - 5 - 5 = 10 - assert len(weights) == 10 - assert weights[0] < 1.0 # First value after 1.0 - assert weights[-1] > 0.0 # Last value before 0.0 - assert torch.all(weights[:-1] >= weights[1:]) # Decreasing - - def test_linweights_with_end_equals_start(): """Test _linweights when end equals start.""" config = RTCConfig() @@ -358,6 +339,31 @@ def test_add_leading_ones_no_ones_needed(): assert torch.equal(result, weights) +def test_get_prefix_weights_with_start_equals_total(): + """Test get_prefix_weights when start equals total.""" + config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.LINEAR) + processor = RTCProcessor(config) + + weights = processor.get_prefix_weights(start=10, end=10, total=20) + + # Should have ones up to start, then zeros + assert len(weights) == 20 + assert torch.all(weights[:10] == 1.0) + assert torch.all(weights[10:] == 0.0) + + +def test_get_prefix_weights_with_total_less_than_start(): + """Test get_prefix_weights when total less than start.""" + config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.LINEAR) + processor = RTCProcessor(config) + + weights = processor.get_prefix_weights(start=10, end=10, total=5) + + # Should have ones up to start, then zeros + assert len(weights) == 5 + assert torch.all(weights == 1.0) + + # ====================== denoise_step Tests ====================== @@ -700,33 +706,6 @@ def test_denoise_step_full_workflow(): assert len(steps) == 1 -def test_get_prefix_weights_integration(): - """Test get_prefix_weights produces expected structure for all schedules.""" - schedules = [ - RTCAttentionSchedule.ZEROS, - RTCAttentionSchedule.ONES, - RTCAttentionSchedule.LINEAR, - RTCAttentionSchedule.EXP, - ] - - for schedule in schedules: - config = RTCConfig(prefix_attention_schedule=schedule) - processor = RTCProcessor(config) - - weights = processor.get_prefix_weights(start=5, end=15, total=20) - - # All should have correct shape - assert weights.shape == (20,) - - # All should be in valid range [0, 1] - assert torch.all(weights >= 0.0) - assert torch.all(weights <= 1.0) - - # All should have no NaN or Inf - assert not torch.any(torch.isnan(weights)) - assert not torch.any(torch.isinf(weights)) - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_denoise_step_with_cuda_tensors(): """Test denoise_step works with CUDA tensors."""