Files
lerobot/tests/policies/rtc/test_modeling_rtc.py
T
Eugene Mironov 99eea2ae03 Add more tests
2025-11-15 00:09:01 +07:00

785 lines
24 KiB
Python

#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for RTC modeling module (RTCProcessor)."""
import pytest
import torch
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
# ====================== Fixtures ======================
@pytest.fixture
def rtc_config_debug_enabled():
"""Create RTC config with debug enabled."""
return RTCConfig(
enabled=True,
prefix_attention_schedule=RTCAttentionSchedule.LINEAR,
max_guidance_weight=10.0,
execution_horizon=10,
debug=True,
debug_maxlen=100,
)
@pytest.fixture
def rtc_config_debug_disabled():
"""Create RTC config with debug disabled."""
return RTCConfig(
enabled=True,
prefix_attention_schedule=RTCAttentionSchedule.LINEAR,
max_guidance_weight=10.0,
execution_horizon=10,
debug=False,
)
@pytest.fixture
def rtc_processor_debug_enabled(rtc_config_debug_enabled):
"""Create RTCProcessor with debug enabled."""
return RTCProcessor(rtc_config_debug_enabled)
@pytest.fixture
def rtc_processor_debug_disabled(rtc_config_debug_disabled):
"""Create RTCProcessor with debug disabled."""
return RTCProcessor(rtc_config_debug_disabled)
@pytest.fixture
def sample_x_t():
"""Create sample x_t tensor (batch, time, action_dim)."""
return torch.randn(1, 50, 6)
@pytest.fixture
def sample_prev_chunk():
"""Create sample previous chunk tensor."""
return torch.randn(1, 50, 6)
# ====================== Initialization Tests ======================
def test_rtc_processor_initialization_with_debug(rtc_config_debug_enabled):
"""Test RTCProcessor initializes with debug tracker."""
processor = RTCProcessor(rtc_config_debug_enabled)
assert processor.rtc_config == rtc_config_debug_enabled
assert processor.tracker is not None
assert processor.tracker.enabled is True
def test_rtc_processor_initialization_without_debug(rtc_config_debug_disabled):
"""Test RTCProcessor initializes without debug tracker."""
processor = RTCProcessor(rtc_config_debug_disabled)
assert processor.rtc_config == rtc_config_debug_disabled
assert processor.tracker is None
# ====================== Tracker Proxy Methods Tests ======================
def test_track_when_tracker_enabled(rtc_processor_debug_enabled, sample_x_t):
"""Test track() forwards to tracker when enabled."""
rtc_processor_debug_enabled.track(
time=torch.tensor(0.5),
x_t=sample_x_t,
v_t=sample_x_t,
guidance_weight=2.0,
)
# Should have tracked one step
steps = rtc_processor_debug_enabled.get_all_debug_steps()
assert len(steps) == 1
assert steps[0].time == 0.5
def test_track_when_tracker_disabled(rtc_processor_debug_disabled, sample_x_t):
"""Test track() does nothing when tracker disabled."""
# Should not raise error
rtc_processor_debug_disabled.track(
time=torch.tensor(0.5),
x_t=sample_x_t,
v_t=sample_x_t,
)
# Should return empty list
steps = rtc_processor_debug_disabled.get_all_debug_steps()
assert len(steps) == 0
def test_get_all_debug_steps_when_enabled(rtc_processor_debug_enabled, sample_x_t):
"""Test get_all_debug_steps() returns tracked steps."""
rtc_processor_debug_enabled.track(time=torch.tensor(0.5), x_t=sample_x_t)
rtc_processor_debug_enabled.track(time=torch.tensor(0.4), x_t=sample_x_t)
steps = rtc_processor_debug_enabled.get_all_debug_steps()
assert len(steps) == 2
def test_get_all_debug_steps_when_disabled(rtc_processor_debug_disabled):
"""Test get_all_debug_steps() returns empty list when disabled."""
steps = rtc_processor_debug_disabled.get_all_debug_steps()
assert steps == []
assert isinstance(steps, list)
def test_is_debug_enabled_when_tracker_exists(rtc_processor_debug_enabled):
"""Test is_debug_enabled() returns True when tracker enabled."""
assert rtc_processor_debug_enabled.is_debug_enabled() is True
def test_is_debug_enabled_when_tracker_disabled(rtc_processor_debug_disabled):
"""Test is_debug_enabled() returns False when tracker disabled."""
assert rtc_processor_debug_disabled.is_debug_enabled() is False
def test_reset_tracker_when_enabled(rtc_processor_debug_enabled, sample_x_t):
"""Test reset_tracker() clears tracked steps."""
rtc_processor_debug_enabled.track(time=torch.tensor(0.5), x_t=sample_x_t)
rtc_processor_debug_enabled.track(time=torch.tensor(0.4), x_t=sample_x_t)
assert len(rtc_processor_debug_enabled.get_all_debug_steps()) == 2
rtc_processor_debug_enabled.reset_tracker()
assert len(rtc_processor_debug_enabled.get_all_debug_steps()) == 0
def test_reset_tracker_when_disabled(rtc_processor_debug_disabled):
"""Test reset_tracker() doesn't error when tracker disabled."""
rtc_processor_debug_disabled.reset_tracker() # Should not raise
# ====================== get_prefix_weights Tests ======================
def test_get_prefix_weights_zeros_schedule():
"""Test get_prefix_weights with ZEROS schedule."""
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.ZEROS)
processor = RTCProcessor(config)
weights = processor.get_prefix_weights(start=5, end=10, total=20)
# First 5 should be 1.0, rest should be 0.0
assert weights.shape == (20,)
assert torch.all(weights[:5] == 1.0)
assert torch.all(weights[5:] == 0.0)
def test_get_prefix_weights_ones_schedule():
"""Test get_prefix_weights with ONES schedule."""
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.ONES)
processor = RTCProcessor(config)
weights = processor.get_prefix_weights(start=5, end=15, total=20)
# First 15 should be 1.0, rest should be 0.0
assert weights.shape == (20,)
assert torch.all(weights[:15] == 1.0)
assert torch.all(weights[15:] == 0.0)
def test_get_prefix_weights_linear_schedule():
"""Test get_prefix_weights with LINEAR schedule."""
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.LINEAR)
processor = RTCProcessor(config)
weights = processor.get_prefix_weights(start=5, end=15, total=20)
# Should have shape (20,)
assert weights.shape == (20,)
# First 5 should be 1.0 (leading ones)
assert torch.all(weights[:5] == 1.0)
# Middle section (5:15) should be linearly decreasing from 1 to 0
middle_weights = weights[5:15]
assert middle_weights[0] > middle_weights[-1] # Decreasing
assert torch.all(middle_weights >= 0.0)
assert torch.all(middle_weights <= 1.0)
# Last 5 should be 0.0 (trailing zeros)
assert torch.all(weights[15:] == 0.0)
def test_get_prefix_weights_exp_schedule():
"""Test get_prefix_weights with EXP schedule."""
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.EXP)
processor = RTCProcessor(config)
weights = processor.get_prefix_weights(start=5, end=15, total=20)
# Should have shape (20,)
assert weights.shape == (20,)
# First 5 should be 1.0 (leading ones)
assert torch.all(weights[:5] == 1.0)
# Middle section should be exponentially weighted
middle_weights = weights[5:15]
assert torch.all(middle_weights >= 0.0)
assert torch.all(middle_weights <= 1.0)
# Last 5 should be 0.0 (trailing zeros)
assert torch.all(weights[15:] == 0.0)
def test_get_prefix_weights_with_start_equals_end():
"""Test get_prefix_weights when start equals end."""
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.LINEAR)
processor = RTCProcessor(config)
weights = processor.get_prefix_weights(start=10, end=10, total=20)
# Should have ones up to start, then zeros
assert torch.all(weights[:10] == 1.0)
assert torch.all(weights[10:] == 0.0)
def test_get_prefix_weights_with_start_greater_than_end():
"""Test get_prefix_weights when start > end (gets clamped)."""
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.LINEAR)
processor = RTCProcessor(config)
# start > end should use min(start, end) = end
weights = processor.get_prefix_weights(start=15, end=10, total=20)
# Should have ones up to end (10), then zeros
assert torch.all(weights[:10] == 1.0)
assert torch.all(weights[10:] == 0.0)
# ====================== Helper Method Tests ======================
def test_linweights_normal_case():
"""Test _linweights with normal parameters."""
config = RTCConfig()
processor = RTCProcessor(config)
weights = processor._linweights(start=5, end=15, total=20)
# Should create linear weights from 1 to 0
# Excluding the endpoints: linspace(1, 0, steps+2)[1:-1]
# Steps = total - (total - end) - start = 20 - 5 - 5 = 10
assert len(weights) == 10
assert weights[0] < 1.0 # First value after 1.0
assert weights[-1] > 0.0 # Last value before 0.0
assert torch.all(weights[:-1] >= weights[1:]) # Decreasing
def test_linweights_with_end_equals_start():
"""Test _linweights when end equals start."""
config = RTCConfig()
processor = RTCProcessor(config)
weights = processor._linweights(start=10, end=10, total=20)
# Should return empty tensor
assert len(weights) == 0
def test_linweights_with_end_less_than_start():
"""Test _linweights when end < start."""
config = RTCConfig()
processor = RTCProcessor(config)
weights = processor._linweights(start=15, end=10, total=20)
# Should return empty tensor
assert len(weights) == 0
def test_add_trailing_zeros_normal():
"""Test _add_trailing_zeros adds zeros correctly."""
config = RTCConfig()
processor = RTCProcessor(config)
weights = torch.tensor([1.0, 0.8, 0.6, 0.4, 0.2])
result = processor._add_trailing_zeros(weights, total=10, end=5)
# Should add 5 zeros (total - end = 10 - 5 = 5)
assert len(result) == 10
assert torch.all(result[:5] == weights)
assert torch.all(result[5:] == 0.0)
def test_add_trailing_zeros_no_zeros_needed():
"""Test _add_trailing_zeros when no zeros needed."""
config = RTCConfig()
processor = RTCProcessor(config)
weights = torch.tensor([1.0, 0.8, 0.6])
result = processor._add_trailing_zeros(weights, total=3, end=5)
# zeros_len = 3 - 5 = -2 <= 0, so no zeros added
assert torch.equal(result, weights)
def test_add_leading_ones_normal():
"""Test _add_leading_ones adds ones correctly."""
config = RTCConfig()
processor = RTCProcessor(config)
weights = torch.tensor([0.8, 0.6, 0.4, 0.2, 0.0])
result = processor._add_leading_ones(weights, start=3, total=10)
# Should add 3 ones at the start
assert len(result) == 8
assert torch.all(result[:3] == 1.0)
assert torch.all(result[3:] == weights)
def test_add_leading_ones_no_ones_needed():
"""Test _add_leading_ones when no ones needed."""
config = RTCConfig()
processor = RTCProcessor(config)
weights = torch.tensor([0.8, 0.6, 0.4])
result = processor._add_leading_ones(weights, start=0, total=10)
# ones_len = 0, so no ones added
assert torch.equal(result, weights)
# ====================== denoise_step Tests ======================
def test_denoise_step_without_prev_chunk(rtc_processor_debug_disabled):
"""Test denoise_step without previous chunk (no guidance)."""
x_t = torch.randn(1, 50, 6)
# Mock denoiser that returns fixed velocity
def mock_denoiser(x):
return torch.ones_like(x) * 0.5
result = rtc_processor_debug_disabled.denoise_step(
x_t=x_t,
prev_chunk_left_over=None,
inference_delay=5,
time=torch.tensor(0.5),
original_denoise_step_partial=mock_denoiser,
)
# Should return v_t unchanged (no guidance)
expected = mock_denoiser(x_t)
assert torch.allclose(result, expected)
def test_denoise_step_with_prev_chunk(rtc_processor_debug_disabled):
"""Test denoise_step with previous chunk applies guidance."""
x_t = torch.randn(1, 50, 6)
prev_chunk = torch.randn(1, 50, 6)
def mock_denoiser(x):
return torch.ones_like(x) * 0.5
result = rtc_processor_debug_disabled.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk,
inference_delay=5,
time=torch.tensor(0.5),
original_denoise_step_partial=mock_denoiser,
)
# Result should be different from base v_t (guidance applied)
base_v_t = mock_denoiser(x_t)
assert not torch.allclose(result, base_v_t)
# Result should have same shape
assert result.shape == x_t.shape
def test_denoise_step_adds_batch_dimension():
"""Test denoise_step handles 2D input by adding batch dimension."""
config = RTCConfig(execution_horizon=10, max_guidance_weight=5.0)
processor = RTCProcessor(config)
# 2D input (no batch dimension)
x_t = torch.randn(50, 6)
prev_chunk = torch.randn(50, 6)
def mock_denoiser(x):
return torch.ones_like(x) * 0.5
result = processor.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk,
inference_delay=5,
time=torch.tensor(0.5),
original_denoise_step_partial=mock_denoiser,
)
# Output should be 2D (batch dimension removed)
assert result.ndim == 2
assert result.shape == (50, 6)
def test_denoise_step_pads_shorter_prev_chunk():
"""Test denoise_step pads previous chunk if shorter than x_t."""
config = RTCConfig(execution_horizon=10, max_guidance_weight=5.0)
processor = RTCProcessor(config)
x_t = torch.randn(1, 50, 6)
prev_chunk = torch.randn(1, 30, 6) # Shorter than x_t
def mock_denoiser(x):
return torch.ones_like(x) * 0.5
result = processor.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk,
inference_delay=5,
time=torch.tensor(0.5),
original_denoise_step_partial=mock_denoiser,
)
# Should complete successfully (padding happens internally)
assert result.shape == x_t.shape
def test_denoise_step_pads_fewer_action_dims():
"""Test denoise_step pads if prev_chunk has fewer action dimensions."""
config = RTCConfig(execution_horizon=10, max_guidance_weight=5.0)
processor = RTCProcessor(config)
x_t = torch.randn(1, 50, 6)
prev_chunk = torch.randn(1, 50, 4) # Fewer action dims
def mock_denoiser(x):
return torch.ones_like(x) * 0.5
result = processor.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk,
inference_delay=5,
time=torch.tensor(0.5),
original_denoise_step_partial=mock_denoiser,
)
# Should complete successfully (padding happens internally)
assert result.shape == x_t.shape
def test_denoise_step_uses_custom_execution_horizon():
"""Test denoise_step uses custom execution_horizon parameter."""
config = RTCConfig(execution_horizon=10)
processor = RTCProcessor(config)
x_t = torch.randn(1, 50, 6)
prev_chunk = torch.randn(1, 50, 6)
def mock_denoiser(x):
return torch.ones_like(x) * 0.5
# Use custom execution_horizon
result = processor.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk,
inference_delay=5,
time=torch.tensor(0.5),
original_denoise_step_partial=mock_denoiser,
execution_horizon=20, # Override config value
)
assert result.shape == x_t.shape
def test_denoise_step_clamps_execution_horizon_to_prev_chunk_length():
"""Test denoise_step clamps execution_horizon if prev_chunk is shorter."""
config = RTCConfig(execution_horizon=100) # Very large
processor = RTCProcessor(config)
x_t = torch.randn(1, 50, 6)
prev_chunk = torch.randn(1, 20, 6) # Only 20 timesteps
def mock_denoiser(x):
return torch.ones_like(x) * 0.5
# Should clamp execution_horizon to 20 internally
result = processor.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk,
inference_delay=5,
time=torch.tensor(0.5),
original_denoise_step_partial=mock_denoiser,
)
assert result.shape == x_t.shape
def test_denoise_step_guidance_weight_calculation():
"""Test denoise_step calculates guidance weight correctly."""
config = RTCConfig(max_guidance_weight=10.0)
processor = RTCProcessor(config)
x_t = torch.randn(1, 50, 6)
prev_chunk = torch.randn(1, 50, 6)
def mock_denoiser(x):
return torch.ones_like(x) * 0.5
# Time = 0.5 => tau = 1 - 0.5 = 0.5
time = 0.5
result = processor.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk,
inference_delay=5,
time=time,
original_denoise_step_partial=mock_denoiser,
)
# Should produce valid output
assert result.shape == x_t.shape
assert not torch.any(torch.isnan(result))
assert not torch.any(torch.isinf(result))
def test_denoise_step_guidance_weight_at_time_zero():
"""Test denoise_step handles time=0 (tau=1) without NaN/Inf."""
config = RTCConfig(max_guidance_weight=10.0)
processor = RTCProcessor(config)
x_t = torch.randn(1, 50, 6)
prev_chunk = torch.randn(1, 50, 6)
def mock_denoiser(x):
return torch.ones_like(x) * 0.5
# Time = 0 => tau = 1, c = (1-tau)/tau = 0/1 = 0
result = processor.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk,
inference_delay=5,
time=torch.tensor(0.0),
original_denoise_step_partial=mock_denoiser,
)
# Should handle gracefully (no NaN/Inf)
assert not torch.any(torch.isnan(result))
assert not torch.any(torch.isinf(result))
def test_denoise_step_guidance_weight_at_time_one():
"""Test denoise_step handles time=1 (tau=0) with max_guidance_weight clamping."""
config = RTCConfig(max_guidance_weight=10.0)
processor = RTCProcessor(config)
x_t = torch.randn(1, 50, 6)
prev_chunk = torch.randn(1, 50, 6)
def mock_denoiser(x):
return torch.ones_like(x) * 0.5
# Time = 1 => tau = 0, c = (1-tau)/tau = 1/0 = inf (clamped to max_guidance_weight)
result = processor.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk,
inference_delay=5,
time=torch.tensor(1.0),
original_denoise_step_partial=mock_denoiser,
)
# Should clamp to max_guidance_weight (no Inf)
assert not torch.any(torch.isinf(result))
def test_denoise_step_tracks_debug_info(rtc_processor_debug_enabled):
"""Test denoise_step tracks debug information when enabled."""
x_t = torch.randn(1, 50, 6)
prev_chunk = torch.randn(1, 50, 6)
def mock_denoiser(x):
return torch.ones_like(x) * 0.5
rtc_processor_debug_enabled.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk,
inference_delay=5,
time=torch.tensor(0.5),
original_denoise_step_partial=mock_denoiser,
)
# Should have tracked one step
steps = rtc_processor_debug_enabled.get_all_debug_steps()
assert len(steps) == 1
# Check tracked values
step = steps[0]
assert step.time == 0.5
assert step.x1_t is not None
assert step.correction is not None
assert step.err is not None
assert step.weights is not None
assert step.guidance_weight is not None
assert step.inference_delay == 5
def test_denoise_step_doesnt_track_without_debug(rtc_processor_debug_disabled):
"""Test denoise_step doesn't track when debug disabled."""
x_t = torch.randn(1, 50, 6)
prev_chunk = torch.randn(1, 50, 6)
def mock_denoiser(x):
return torch.ones_like(x) * 0.5
rtc_processor_debug_disabled.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk,
inference_delay=5,
time=torch.tensor(0.5),
original_denoise_step_partial=mock_denoiser,
)
# Should not track
steps = rtc_processor_debug_disabled.get_all_debug_steps()
assert len(steps) == 0
# ====================== Integration Tests ======================
def test_denoise_step_full_workflow():
"""Test complete denoise_step workflow."""
config = RTCConfig(
enabled=True,
prefix_attention_schedule=RTCAttentionSchedule.LINEAR,
max_guidance_weight=5.0,
execution_horizon=10,
debug=True,
)
processor = RTCProcessor(config)
# Simulate two denoising steps
x_t1 = torch.randn(1, 50, 6)
x_t2 = torch.randn(1, 50, 6)
def mock_denoiser(x):
return torch.randn_like(x) * 0.1
# First step - no guidance
result1 = processor.denoise_step(
x_t=x_t1,
prev_chunk_left_over=None,
inference_delay=5,
time=torch.tensor(0.8),
original_denoise_step_partial=mock_denoiser,
)
# Second step - with guidance
result2 = processor.denoise_step(
x_t=x_t2,
prev_chunk_left_over=result1,
inference_delay=5,
time=torch.tensor(0.6),
original_denoise_step_partial=mock_denoiser,
)
# Both should complete successfully
assert result1.shape == (1, 50, 6)
assert result2.shape == (1, 50, 6)
# Should have tracked one step (second one, first had no prev_chunk)
steps = processor.get_all_debug_steps()
assert len(steps) == 1
def test_get_prefix_weights_integration():
"""Test get_prefix_weights produces expected structure for all schedules."""
schedules = [
RTCAttentionSchedule.ZEROS,
RTCAttentionSchedule.ONES,
RTCAttentionSchedule.LINEAR,
RTCAttentionSchedule.EXP,
]
for schedule in schedules:
config = RTCConfig(prefix_attention_schedule=schedule)
processor = RTCProcessor(config)
weights = processor.get_prefix_weights(start=5, end=15, total=20)
# All should have correct shape
assert weights.shape == (20,)
# All should be in valid range [0, 1]
assert torch.all(weights >= 0.0)
assert torch.all(weights <= 1.0)
# All should have no NaN or Inf
assert not torch.any(torch.isnan(weights))
assert not torch.any(torch.isinf(weights))
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_denoise_step_with_cuda_tensors():
"""Test denoise_step works with CUDA tensors."""
config = RTCConfig(execution_horizon=10, max_guidance_weight=5.0)
processor = RTCProcessor(config)
x_t = torch.randn(1, 50, 6, device="cuda")
prev_chunk = torch.randn(1, 50, 6, device="cuda")
def mock_denoiser(x):
return torch.ones_like(x) * 0.5
result = processor.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk,
inference_delay=5,
time=torch.tensor(0.5),
original_denoise_step_partial=mock_denoiser,
)
# Result should be on CUDA
assert result.device.type == "cuda"
assert result.shape == x_t.shape
def test_denoise_step_deterministic_with_same_inputs():
"""Test denoise_step produces same output with same inputs."""
config = RTCConfig(execution_horizon=10, max_guidance_weight=5.0)
processor = RTCProcessor(config)
torch.manual_seed(42)
x_t = torch.randn(1, 50, 6)
prev_chunk = torch.randn(1, 50, 6)
def deterministic_denoiser(x):
return torch.ones_like(x) * 0.5
result1 = processor.denoise_step(
x_t=x_t.clone(),
prev_chunk_left_over=prev_chunk.clone(),
inference_delay=5,
time=torch.tensor(0.5),
original_denoise_step_partial=deterministic_denoiser,
)
result2 = processor.denoise_step(
x_t=x_t.clone(),
prev_chunk_left_over=prev_chunk.clone(),
inference_delay=5,
time=torch.tensor(0.5),
original_denoise_step_partial=deterministic_denoiser,
)
# Should produce identical results
assert torch.allclose(result1, result2)