def validate_rtc_behavior(
    rtc_actions: torch.Tensor,
    no_rtc_actions: torch.Tensor,
    prev_chunk: torch.Tensor,
    inference_delay: int,
    execution_horizon: int,
    rtol: float = 1e-2,
):
    """Check that RTC-guided actions obey the three expected blending rules.

    Rule 1: during the inference delay the RTC actions replay the previous chunk.
    Rule 2: inside the execution horizon they lie between prev_chunk and no_rtc.
    Rule 3: past the horizon RTC guidance has no effect (equals no_rtc).

    Returns:
        Tuple of (all_passed, failures) where failures is a list of error messages
    """

    def _to_cpu_unbatched(t: torch.Tensor) -> torch.Tensor:
        # Drop a leading batch dim when present and move to CPU for comparison.
        return t.squeeze(0).cpu() if len(t.shape) == 3 else t.cpu()

    rtc = _to_cpu_unbatched(rtc_actions)
    base = _to_cpu_unbatched(no_rtc_actions)
    prev = _to_cpu_unbatched(prev_chunk)

    chunk_len = min(rtc.shape[0], base.shape[0], prev.shape[0])
    failures = []

    # Rule 1: delay region [0:inference_delay] must replay prev_chunk verbatim.
    if inference_delay > 0:
        delay_end = min(inference_delay, chunk_len)
        if not torch.allclose(rtc[:delay_end], prev[:delay_end], rtol=rtol):
            max_diff = torch.max(torch.abs(rtc[:delay_end] - prev[:delay_end])).item()
            failures.append(
                f"Delay region [0:{delay_end}]: RTC does NOT equal prev_chunk (max diff: {max_diff:.6f})"
            )

    # Rule 2: blend region [inference_delay:execution_horizon] must stay
    # element-wise between the previous chunk and the unguided actions.
    blend_start = inference_delay
    blend_end = min(execution_horizon, chunk_len)
    if blend_end > blend_start:
        blended = rtc[blend_start:blend_end]
        lo = torch.minimum(prev[blend_start:blend_end], base[blend_start:blend_end])
        hi = torch.maximum(prev[blend_start:blend_end], base[blend_start:blend_end])
        in_range = torch.logical_and(blended >= lo, blended <= hi)
        if not torch.all(in_range):
            violations = torch.sum(~in_range).item()
            total_elements = in_range.numel()
            failures.append(
                f"Blend region [{blend_start}:{blend_end}]: "
                f"RTC is NOT between prev_chunk and no_rtc ({violations}/{total_elements} violations)"
            )

    # Rule 3: post-horizon [execution_horizon:] must match the unguided actions.
    if execution_horizon < chunk_len:
        tail_rtc = rtc[execution_horizon:chunk_len]
        tail_base = base[execution_horizon:chunk_len]
        if not torch.allclose(tail_rtc, tail_base, rtol=rtol):
            max_diff = torch.max(torch.abs(tail_rtc - tail_base)).item()
            failures.append(
                f"Post-horizon [{execution_horizon}:{chunk_len}]: "
                f"RTC does NOT equal no_rtc (max diff: {max_diff:.6f})"
            )

    return len(failures) == 0, failures
@require_cuda
def test_pi05_rtc_initialization():
    """Test PI0.5 policy can initialize RTC processor."""
    set_seed(42)

    cfg = PI05Config(max_action_dim=7, max_state_dim=14, dtype="float32")

    # Attach an RTC configuration before constructing the policy.
    cfg.rtc_config = RTCConfig(
        enabled=True,
        execution_horizon=10,
        max_guidance_weight=5.0,
        prefix_attention_schedule=RTCAttentionSchedule.EXP,
        debug=False,
    )
    cfg.input_features = {
        "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
        "observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
    }
    cfg.output_features = {
        "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
    }

    policy = PI05Policy(cfg)

    # The policy must expose a live, enabled RTC processor.
    assert hasattr(policy, "rtc_processor")
    assert policy.rtc_processor is not None
    assert policy.rtc_processor.rtc_config.enabled is True

    print("✓ PI0.5 RTC initialization: Test passed")


@require_cuda
def test_pi05_rtc_inference_with_prev_chunk():
    """Test PI0.5 policy inference with RTC and previous chunk."""
    set_seed(42)

    cfg = PI05Config(max_action_dim=7, max_state_dim=14, chunk_size=50, dtype="float32")
    cfg.rtc_config = RTCConfig(
        enabled=True,
        execution_horizon=10,
        max_guidance_weight=5.0,
        prefix_attention_schedule=RTCAttentionSchedule.EXP,
        debug=False,
    )
    cfg.input_features = {
        "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
        "observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
    }
    cfg.output_features = {
        "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
    }

    # Normalization stats matching the declared features.
    stats = {
        "observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
        "action": {"mean": torch.zeros(7), "std": torch.ones(7)},
        "observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
    }

    policy = PI05Policy(cfg)
    policy.eval()
    preprocessor, _ = make_pi05_pre_post_processors(config=cfg, dataset_stats=stats)

    device = cfg.device
    batch = preprocessor(
        {
            "observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
            "observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
            "task": ["Pick up the object"],
        }
    )

    # Leftover actions from a hypothetical previous chunk.
    leftover = torch.randn(1, 25, 7, dtype=torch.float32, device=device)

    with torch.no_grad():
        # Share the same starting noise so the two runs are directly comparable.
        noise = policy.model.sample_noise((1, cfg.chunk_size, 7), device)

        guided = policy.predict_action_chunk(
            batch,
            noise=noise.clone(),
            prev_chunk_left_over=leftover,
            inference_delay=4,
            execution_horizon=10,
        )

        # Temporarily disable RTC to obtain the unguided baseline.
        policy.config.rtc_config.enabled = False
        unguided = policy.predict_action_chunk(batch, noise=noise.clone())
        policy.config.rtc_config.enabled = True

        assert guided.shape == (1, cfg.chunk_size, 7)
        assert unguided.shape == (1, cfg.chunk_size, 7)

        # With a leftover chunk supplied, RTC guidance must alter the actions.
        assert not torch.allclose(guided, unguided, rtol=1e-3)

    print("✓ PI0.5 RTC inference with prev_chunk: Test passed")
@require_cuda
def test_pi05_rtc_inference_without_prev_chunk():
    """Test PI0.5 policy inference with RTC but no previous chunk (RTC should have no effect)."""
    set_seed(42)

    cfg = PI05Config(max_action_dim=7, max_state_dim=14, chunk_size=50, dtype="float32")
    cfg.rtc_config = RTCConfig(
        enabled=True,
        execution_horizon=10,
        max_guidance_weight=5.0,
        prefix_attention_schedule=RTCAttentionSchedule.EXP,
        debug=False,
    )
    cfg.input_features = {
        "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
        "observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
    }
    cfg.output_features = {
        "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
    }

    stats = {
        "observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
        "action": {"mean": torch.zeros(7), "std": torch.ones(7)},
        "observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
    }

    policy = PI05Policy(cfg)
    policy.eval()
    preprocessor, _ = make_pi05_pre_post_processors(config=cfg, dataset_stats=stats)

    device = cfg.device
    batch = preprocessor(
        {
            "observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
            "observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
            "task": ["Pick up the object"],
        }
    )

    with torch.no_grad():
        noise = policy.model.sample_noise((1, cfg.chunk_size, 7), device)

        # RTC is enabled but there is nothing to blend with.
        rtc_no_prev = policy.predict_action_chunk(
            batch,
            noise=noise.clone(),
            prev_chunk_left_over=None,
        )

        policy.config.rtc_config.enabled = False
        baseline = policy.predict_action_chunk(batch, noise=noise.clone())
        policy.config.rtc_config.enabled = True

        # Without a previous chunk, RTC must be a no-op.
        assert torch.allclose(rtc_no_prev, baseline, rtol=1e-5)

    print("✓ PI0.5 RTC inference without prev_chunk: Test passed")


@require_cuda
def test_pi05_rtc_validation_rules():
    """Test PI0.5 policy with RTC follows all three validation rules."""
    set_seed(42)

    cfg = PI05Config(max_action_dim=7, max_state_dim=14, chunk_size=50, dtype="float32")
    cfg.rtc_config = RTCConfig(
        enabled=True,
        execution_horizon=10,
        max_guidance_weight=5.0,
        prefix_attention_schedule=RTCAttentionSchedule.EXP,
        debug=False,
    )
    cfg.input_features = {
        "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
        "observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
    }
    cfg.output_features = {
        "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
    }

    stats = {
        "observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
        "action": {"mean": torch.zeros(7), "std": torch.ones(7)},
        "observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
    }

    policy = PI05Policy(cfg)
    policy.eval()
    preprocessor, _ = make_pi05_pre_post_processors(config=cfg, dataset_stats=stats)

    device = cfg.device
    batch = preprocessor(
        {
            "observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
            "observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
            "task": ["Pick up the object"],
        }
    )

    leftover = torch.randn(1, 25, 7, dtype=torch.float32, device=device)
    inference_delay = 4
    execution_horizon = 10

    with torch.no_grad():
        # Shared noise so guided and unguided chunks differ only by RTC.
        noise = policy.model.sample_noise((1, cfg.chunk_size, 7), device)

        guided = policy.predict_action_chunk(
            batch,
            noise=noise.clone(),
            prev_chunk_left_over=leftover,
            inference_delay=inference_delay,
            execution_horizon=execution_horizon,
        )

        policy.config.rtc_config.enabled = False
        unguided = policy.predict_action_chunk(batch, noise=noise.clone())
        policy.config.rtc_config.enabled = True

        # Check delay / blend / post-horizon rules against the baseline.
        all_passed, failures = validate_rtc_behavior(
            rtc_actions=guided,
            no_rtc_actions=unguided,
            prev_chunk=leftover,
            inference_delay=inference_delay,
            execution_horizon=execution_horizon,
        )

        if not all_passed:
            pytest.fail("RTC validation failed:\n" + "\n".join(failures))

    print("✓ PI0.5 RTC validation rules: All rules passed")
    print("  ✓ Delay region [0:4]: RTC = prev_chunk")
    print("  ✓ Blend region [4:10]: prev_chunk ≤ RTC ≤ no_rtc")
    print("  ✓ Post-horizon [10:]: RTC = no_rtc")
+ +"""Test PI0 policy with Real-Time Chunking (RTC) enabled during inference.""" + +import os + +import pytest +import torch + +# Skip this entire module in CI +pytestmark = pytest.mark.skipif( + os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true", + reason="This test requires local OpenPI installation and is not meant for CI", +) + +from lerobot.configs.types import FeatureType, PolicyFeature, RTCAttentionSchedule # noqa: E402 +from lerobot.policies.pi0 import PI0Config, PI0Policy, make_pi0_pre_post_processors # noqa: E402 +from lerobot.policies.rtc.configuration_rtc import RTCConfig # noqa: E402 +from lerobot.utils.random_utils import set_seed # noqa: E402 +from tests.utils import require_cuda # noqa: E402 + + +def validate_rtc_behavior( + rtc_actions: torch.Tensor, + no_rtc_actions: torch.Tensor, + prev_chunk: torch.Tensor, + inference_delay: int, + execution_horizon: int, + rtol: float = 1e-2, +): + """Validate RTC behavior follows expected rules. + + Returns: + Tuple of (all_passed, failures) where failures is a list of error messages + """ + # Remove batch dimension if present and move to CPU + rtc_actions_t = rtc_actions.squeeze(0).cpu() if len(rtc_actions.shape) == 3 else rtc_actions.cpu() + no_rtc_actions_t = ( + no_rtc_actions.squeeze(0).cpu() if len(no_rtc_actions.shape) == 3 else no_rtc_actions.cpu() + ) + prev_chunk_t = prev_chunk.squeeze(0).cpu() if len(prev_chunk.shape) == 3 else prev_chunk.cpu() + + chunk_len = min(rtc_actions_t.shape[0], no_rtc_actions_t.shape[0], prev_chunk_t.shape[0]) + failures = [] + + # Rule 1: Delay region [0:inference_delay] - RTC should equal prev_chunk + if inference_delay > 0: + delay_end = min(inference_delay, chunk_len) + rtc_delay = rtc_actions_t[:delay_end] + prev_delay = prev_chunk_t[:delay_end] + + if not torch.allclose(rtc_delay, prev_delay, rtol=rtol): + max_diff = torch.max(torch.abs(rtc_delay - prev_delay)).item() + failures.append( + f"Delay region [0:{delay_end}]: RTC does NOT equal 
prev_chunk (max diff: {max_diff:.6f})" + ) + + # Rule 2: Blend region [inference_delay:execution_horizon] + blend_start = inference_delay + blend_end = min(execution_horizon, chunk_len) + + if blend_end > blend_start: + rtc_blend = rtc_actions_t[blend_start:blend_end] + prev_blend = prev_chunk_t[blend_start:blend_end] + no_rtc_blend = no_rtc_actions_t[blend_start:blend_end] + + min_bound = torch.minimum(prev_blend, no_rtc_blend) + max_bound = torch.maximum(prev_blend, no_rtc_blend) + within_bounds = torch.logical_and(rtc_blend >= min_bound, rtc_blend <= max_bound) + + if not torch.all(within_bounds): + violations = torch.sum(~within_bounds).item() + total_elements = within_bounds.numel() + failures.append( + f"Blend region [{blend_start}:{blend_end}]: " + f"RTC is NOT between prev_chunk and no_rtc ({violations}/{total_elements} violations)" + ) + + # Rule 3: Post-horizon [execution_horizon:] - RTC should equal no_rtc + if execution_horizon < chunk_len: + rtc_after = rtc_actions_t[execution_horizon:chunk_len] + no_rtc_after = no_rtc_actions_t[execution_horizon:chunk_len] + + if not torch.allclose(rtc_after, no_rtc_after, rtol=rtol): + max_diff = torch.max(torch.abs(rtc_after - no_rtc_after)).item() + failures.append( + f"Post-horizon [{execution_horizon}:{chunk_len}]: " + f"RTC does NOT equal no_rtc (max diff: {max_diff:.6f})" + ) + + return len(failures) == 0, failures + + +@require_cuda +def test_pi0_rtc_initialization(): + """Test PI0 policy can initialize RTC processor.""" + set_seed(42) + + config = PI0Config(max_action_dim=7, max_state_dim=14, dtype="float32") + + # Add RTC config + config.rtc_config = RTCConfig( + enabled=True, + execution_horizon=10, + max_guidance_weight=5.0, + prefix_attention_schedule=RTCAttentionSchedule.EXP, + debug=False, + ) + + config.input_features = { + "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)), + "observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + } + 
@require_cuda
def test_pi0_rtc_inference_with_prev_chunk():
    """Test PI0 policy inference with RTC and previous chunk."""
    set_seed(42)

    cfg = PI0Config(max_action_dim=7, max_state_dim=14, chunk_size=50, dtype="float32")

    # Attach an RTC configuration before constructing the policy.
    cfg.rtc_config = RTCConfig(
        enabled=True,
        execution_horizon=10,
        max_guidance_weight=5.0,
        prefix_attention_schedule=RTCAttentionSchedule.EXP,
        debug=False,
    )
    cfg.input_features = {
        "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
        "observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
    }
    cfg.output_features = {
        "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
    }

    # Normalization stats matching the declared features.
    stats = {
        "observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
        "action": {"mean": torch.zeros(7), "std": torch.ones(7)},
        "observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
    }

    policy = PI0Policy(cfg)
    policy.eval()
    preprocessor, _ = make_pi0_pre_post_processors(config=cfg, dataset_stats=stats)

    device = cfg.device
    batch = preprocessor(
        {
            "observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
            "observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
            "task": ["Pick up the object"],
        }
    )

    # Leftover actions from a hypothetical previous chunk.
    leftover = torch.randn(1, 25, 7, dtype=torch.float32, device=device)

    with torch.no_grad():
        # Share the same starting noise so the two runs are directly comparable.
        noise = policy.model.sample_noise((1, cfg.chunk_size, 7), device)

        guided = policy.predict_action_chunk(
            batch,
            noise=noise.clone(),
            prev_chunk_left_over=leftover,
            inference_delay=4,
            execution_horizon=10,
        )

        # Temporarily disable RTC to obtain the unguided baseline.
        policy.config.rtc_config.enabled = False
        unguided = policy.predict_action_chunk(batch, noise=noise.clone())
        policy.config.rtc_config.enabled = True

        assert guided.shape == (1, cfg.chunk_size, 7)
        assert unguided.shape == (1, cfg.chunk_size, 7)

        # With a leftover chunk supplied, RTC guidance must alter the actions.
        assert not torch.allclose(guided, unguided, rtol=1e-3)

    print("✓ PI0 RTC inference with prev_chunk: Test passed")
@require_cuda
def test_pi0_rtc_inference_without_prev_chunk():
    """Test PI0 policy inference with RTC but no previous chunk (RTC should have no effect)."""
    set_seed(42)

    config = PI0Config(max_action_dim=7, max_state_dim=14, chunk_size=50, dtype="float32")

    # Add RTC config
    config.rtc_config = RTCConfig(
        enabled=True,
        execution_horizon=10,
        max_guidance_weight=5.0,
        prefix_attention_schedule=RTCAttentionSchedule.EXP,
        debug=False,
    )

    config.input_features = {
        "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
        "observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
    }
    config.output_features = {
        "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
    }

    # Create dataset stats
    dataset_stats = {
        "observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
        "action": {"mean": torch.zeros(7), "std": torch.ones(7)},
        "observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
    }

    # Instantiate policy and preprocessor
    policy = PI0Policy(config)
    policy.eval()
    preprocessor, _ = make_pi0_pre_post_processors(config=config, dataset_stats=dataset_stats)

    device = config.device

    # Create dummy batch
    batch = {
        "observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
        "observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
        "task": ["Pick up the object"],
    }
    batch = preprocessor(batch)

    with torch.no_grad():
        # Use same noise for fair comparison
        noise = policy.model.sample_noise((1, config.chunk_size, 7), device)

        # Test with RTC enabled but no previous chunk
        actions_with_rtc_no_prev = policy.predict_action_chunk(
            batch,
            noise=noise.clone(),
            prev_chunk_left_over=None,
        )

        # Test without RTC
        policy.config.rtc_config.enabled = False
        actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone())
        policy.config.rtc_config.enabled = True

        # Without previous chunk, RTC should have no effect
        assert torch.allclose(actions_with_rtc_no_prev, actions_without_rtc, rtol=1e-5)

    print("✓ PI0 RTC inference without prev_chunk: Test passed")


@require_cuda
def test_pi0_rtc_validation_rules():
    """Test PI0 policy with RTC follows all three validation rules."""
    set_seed(42)

    config = PI0Config(max_action_dim=7, max_state_dim=14, chunk_size=50, dtype="float32")

    # Add RTC config
    config.rtc_config = RTCConfig(
        enabled=True,
        execution_horizon=10,
        max_guidance_weight=5.0,
        prefix_attention_schedule=RTCAttentionSchedule.EXP,
        debug=False,
    )

    config.input_features = {
        "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
        "observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
    }
    config.output_features = {
        "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
    }

    # Create dataset stats
    dataset_stats = {
        "observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
        "action": {"mean": torch.zeros(7), "std": torch.ones(7)},
        "observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
    }

    # Instantiate policy and preprocessor
    policy = PI0Policy(config)
    policy.eval()
    preprocessor, _ = make_pi0_pre_post_processors(config=config, dataset_stats=dataset_stats)

    device = config.device

    # Create dummy batch
    batch = {
        "observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
        "observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
        "task": ["Pick up the object"],
    }
    batch = preprocessor(batch)

    # Create previous chunk
    prev_chunk = torch.randn(1, 25, 7, dtype=torch.float32, device=device)

    inference_delay = 4
    execution_horizon = 10

    with torch.no_grad():
        # Use same noise for fair comparison
        noise = policy.model.sample_noise((1, config.chunk_size, 7), device)

        # Test with RTC
        actions_with_rtc = policy.predict_action_chunk(
            batch,
            noise=noise.clone(),
            prev_chunk_left_over=prev_chunk,
            inference_delay=inference_delay,
            execution_horizon=execution_horizon,
        )

        # Test without RTC
        policy.config.rtc_config.enabled = False
        actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone())
        policy.config.rtc_config.enabled = True

        # Validate RTC behavior rules
        all_passed, failures = validate_rtc_behavior(
            rtc_actions=actions_with_rtc,
            no_rtc_actions=actions_without_rtc,
            prev_chunk=prev_chunk,
            inference_delay=inference_delay,
            execution_horizon=execution_horizon,
        )

        if not all_passed:
            error_msg = "RTC validation failed:\n" + "\n".join(failures)
            pytest.fail(error_msg)

    print("✓ PI0 RTC validation rules: All rules passed")
    print("  ✓ Delay region [0:4]: RTC = prev_chunk")
    print("  ✓ Blend region [4:10]: prev_chunk ≤ RTC ≤ no_rtc")
    print("  ✓ Post-horizon [10:]: RTC = no_rtc")


@require_cuda
def test_pi0_rtc_different_schedules():
    """Test PI0 with different RTC attention schedules.

    NOTE(review): this body was previously appended to the tail of
    test_pi0_rtc_validation_rules with only a stray docstring and no
    ``def``/decorator of its own, so it silently ran as part of that test
    instead of being collected as a separate test item. It is now a
    proper standalone test function.
    """
    set_seed(42)

    # Every supported prefix-attention schedule must produce a valid chunk.
    schedules = [
        RTCAttentionSchedule.ZEROS,
        RTCAttentionSchedule.ONES,
        RTCAttentionSchedule.LINEAR,
        RTCAttentionSchedule.EXP,
    ]

    config = PI0Config(max_action_dim=7, max_state_dim=14, chunk_size=50, dtype="float32")

    config.input_features = {
        "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
        "observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
    }
    config.output_features = {
        "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
    }

    # Create dataset stats
    dataset_stats = {
        "observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
        "action": {"mean": torch.zeros(7), "std": torch.ones(7)},
        "observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
    }

    device = config.device

    for schedule in schedules:
        print(f"Testing schedule: {schedule}")

        # Add RTC config with specific schedule
        config.rtc_config = RTCConfig(
            enabled=True,
            execution_horizon=10,
            max_guidance_weight=5.0,
            prefix_attention_schedule=schedule,
            debug=False,
        )

        # Fresh policy per schedule so the RTC processor picks up the config.
        policy = PI0Policy(config)
        policy.eval()
        preprocessor, _ = make_pi0_pre_post_processors(config=config, dataset_stats=dataset_stats)

        # Create dummy batch
        batch = {
            "observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
            "observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
            "task": ["Pick up the object"],
        }
        batch = preprocessor(batch)

        # Create previous chunk
        prev_chunk = torch.randn(1, 25, 7, dtype=torch.float32, device=device)

        with torch.no_grad():
            noise = policy.model.sample_noise((1, config.chunk_size, 7), device)
            actions = policy.predict_action_chunk(
                batch,
                noise=noise,
                prev_chunk_left_over=prev_chunk,
                inference_delay=4,
                execution_horizon=10,
            )

        # Verify shape
        assert actions.shape == (1, config.chunk_size, 7)
        print(f"  ✓ Schedule {schedule}: Test passed")

    print("✓ PI0 RTC different schedules: All schedules tested")
print(f" ✓ Schedule {schedule}: Test passed") + + print("✓ PI0 RTC different schedules: All schedules tested") diff --git a/tests/policies/smolvla/test_smolvla_rtc.py b/tests/policies/smolvla/test_smolvla_rtc.py new file mode 100644 index 000000000..1f5a7a524 --- /dev/null +++ b/tests/policies/smolvla/test_smolvla_rtc.py @@ -0,0 +1,456 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test SmolVLA policy with Real-Time Chunking (RTC) enabled during inference.""" + +import pytest +import torch + +from lerobot.configs.types import FeatureType, PolicyFeature, RTCAttentionSchedule # noqa: E402 +from lerobot.policies.factory import make_pre_post_processors # noqa: E402 +from lerobot.policies.rtc.configuration_rtc import RTCConfig # noqa: E402 +from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig # noqa: F401 +from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy # noqa: F401 +from lerobot.utils.random_utils import set_seed # noqa: E402 +from tests.utils import require_cuda # noqa: E402 + + +def validate_rtc_behavior( + rtc_actions: torch.Tensor, + no_rtc_actions: torch.Tensor, + prev_chunk: torch.Tensor, + inference_delay: int, + execution_horizon: int, + rtol: float = 1e-2, +): + """Validate RTC behavior follows expected rules. 
def validate_rtc_behavior(
    rtc_actions: torch.Tensor,
    no_rtc_actions: torch.Tensor,
    prev_chunk: torch.Tensor,
    inference_delay: int,
    execution_horizon: int,
    rtol: float = 1e-2,
):
    """Check that RTC-guided actions obey the three expected blending rules.

    Rule 1: during the inference delay the RTC actions replay the previous chunk.
    Rule 2: inside the execution horizon they lie between prev_chunk and no_rtc.
    Rule 3: past the horizon RTC guidance has no effect (equals no_rtc).

    Returns:
        Tuple of (all_passed, failures) where failures is a list of error messages
    """

    def _to_cpu_unbatched(t: torch.Tensor) -> torch.Tensor:
        # Drop a leading batch dim when present and move to CPU for comparison.
        return t.squeeze(0).cpu() if len(t.shape) == 3 else t.cpu()

    rtc = _to_cpu_unbatched(rtc_actions)
    base = _to_cpu_unbatched(no_rtc_actions)
    prev = _to_cpu_unbatched(prev_chunk)

    chunk_len = min(rtc.shape[0], base.shape[0], prev.shape[0])
    failures = []

    # Rule 1: delay region [0:inference_delay] must replay prev_chunk verbatim.
    if inference_delay > 0:
        delay_end = min(inference_delay, chunk_len)
        if not torch.allclose(rtc[:delay_end], prev[:delay_end], rtol=rtol):
            max_diff = torch.max(torch.abs(rtc[:delay_end] - prev[:delay_end])).item()
            failures.append(
                f"Delay region [0:{delay_end}]: RTC does NOT equal prev_chunk (max diff: {max_diff:.6f})"
            )

    # Rule 2: blend region [inference_delay:execution_horizon] must stay
    # element-wise between the previous chunk and the unguided actions.
    blend_start = inference_delay
    blend_end = min(execution_horizon, chunk_len)
    if blend_end > blend_start:
        blended = rtc[blend_start:blend_end]
        lo = torch.minimum(prev[blend_start:blend_end], base[blend_start:blend_end])
        hi = torch.maximum(prev[blend_start:blend_end], base[blend_start:blend_end])
        in_range = torch.logical_and(blended >= lo, blended <= hi)
        if not torch.all(in_range):
            violations = torch.sum(~in_range).item()
            total_elements = in_range.numel()
            failures.append(
                f"Blend region [{blend_start}:{blend_end}]: "
                f"RTC is NOT between prev_chunk and no_rtc ({violations}/{total_elements} violations)"
            )

    # Rule 3: post-horizon [execution_horizon:] must match the unguided actions.
    if execution_horizon < chunk_len:
        tail_rtc = rtc[execution_horizon:chunk_len]
        tail_base = base[execution_horizon:chunk_len]
        if not torch.allclose(tail_rtc, tail_base, rtol=rtol):
            max_diff = torch.max(torch.abs(tail_rtc - tail_base)).item()
            failures.append(
                f"Post-horizon [{execution_horizon}:{chunk_len}]: "
                f"RTC does NOT equal no_rtc (max diff: {max_diff:.6f})"
            )

    return len(failures) == 0, failures
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + } + config.output_features = { + "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)), + } + + # Create dataset stats + dataset_stats = { + "observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)}, + "action": {"mean": torch.zeros(7), "std": torch.ones(7)}, + "observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)}, + } + + # Instantiate policy and create preprocessor + policy = SmolVLAPolicy(config) + policy.eval() + preprocessor, _ = make_pre_post_processors( + policy_cfg=config, pretrained_path=None, dataset_stats=dataset_stats + ) + + device = config.device + + # Create dummy batch + batch = { + "observation.state": torch.randn(1, 14, dtype=torch.float32, device=device), + "observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device), + "task": ["Pick up the object"], + } + batch = preprocessor(batch) + + # Create previous chunk + prev_chunk = torch.randn(1, 25, 7, dtype=torch.float32, device=device) + + with torch.no_grad(): + # Use same noise for fair comparison + noise = policy.model.sample_noise((1, config.chunk_size, 7), device) + + # Test with RTC and previous chunk + actions_with_rtc = policy.predict_action_chunk( + batch, + noise=noise.clone(), + prev_chunk_left_over=prev_chunk, + inference_delay=4, + execution_horizon=10, + ) + + # Test without RTC for comparison + policy.config.rtc_config.enabled = False + actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone()) + policy.config.rtc_config.enabled = True + + # Verify shapes + assert actions_with_rtc.shape == (1, config.chunk_size, 7) + assert actions_without_rtc.shape == (1, config.chunk_size, 7) + + # With previous chunk, actions should be different (RTC guidance applied) + assert not torch.allclose(actions_with_rtc, actions_without_rtc, rtol=1e-3) + + print("✓ SmolVLA RTC inference with 
prev_chunk: Test passed") + + +@require_cuda +@pytest.mark.skipif(True, reason="Requires pretrained SmolVLA model weights") +def test_smolvla_rtc_inference_without_prev_chunk(): + """Test SmolVLA policy inference with RTC but no previous chunk (RTC should have no effect).""" + set_seed(42) + + config = SmolVLAConfig(max_action_dim=7, chunk_size=50) + + # Add RTC config + config.rtc_config = RTCConfig( + enabled=True, + execution_horizon=10, + max_guidance_weight=5.0, + prefix_attention_schedule=RTCAttentionSchedule.EXP, + debug=False, + ) + + config.input_features = { + "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)), + "observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + } + config.output_features = { + "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)), + } + + # Create dataset stats + dataset_stats = { + "observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)}, + "action": {"mean": torch.zeros(7), "std": torch.ones(7)}, + "observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)}, + } + + # Instantiate policy and create preprocessor + policy = SmolVLAPolicy(config) + policy.eval() + preprocessor, _ = make_pre_post_processors( + policy_cfg=config, pretrained_path=None, dataset_stats=dataset_stats + ) + + device = config.device + + # Create dummy batch + batch = { + "observation.state": torch.randn(1, 14, dtype=torch.float32, device=device), + "observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device), + "task": ["Pick up the object"], + } + batch = preprocessor(batch) + + with torch.no_grad(): + # Use same noise for fair comparison + noise = policy.model.sample_noise((1, config.chunk_size, 7), device) + + # Test with RTC enabled but no previous chunk + actions_with_rtc_no_prev = policy.predict_action_chunk( + batch, + noise=noise.clone(), + prev_chunk_left_over=None, + ) + + # Test without RTC + 
policy.config.rtc_config.enabled = False + actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone()) + policy.config.rtc_config.enabled = True + + # Without previous chunk, RTC should have no effect + assert torch.allclose(actions_with_rtc_no_prev, actions_without_rtc, rtol=1e-5) + + print("✓ SmolVLA RTC inference without prev_chunk: Test passed") + + +@require_cuda +@pytest.mark.skipif(True, reason="Requires pretrained SmolVLA model weights") +def test_smolvla_rtc_validation_rules(): + """Test SmolVLA policy with RTC follows all three validation rules.""" + set_seed(42) + + config = SmolVLAConfig(max_action_dim=7, chunk_size=50) + + # Add RTC config + config.rtc_config = RTCConfig( + enabled=True, + execution_horizon=10, + max_guidance_weight=5.0, + prefix_attention_schedule=RTCAttentionSchedule.EXP, + debug=False, + ) + + config.input_features = { + "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)), + "observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + } + config.output_features = { + "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)), + } + + # Create dataset stats + dataset_stats = { + "observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)}, + "action": {"mean": torch.zeros(7), "std": torch.ones(7)}, + "observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)}, + } + + # Instantiate policy and create preprocessor + policy = SmolVLAPolicy(config) + policy.eval() + preprocessor, _ = make_pre_post_processors( + policy_cfg=config, pretrained_path=None, dataset_stats=dataset_stats + ) + + device = config.device + + # Create dummy batch + batch = { + "observation.state": torch.randn(1, 14, dtype=torch.float32, device=device), + "observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device), + "task": ["Pick up the object"], + } + batch = preprocessor(batch) + + # Create previous chunk + 
prev_chunk = torch.randn(1, 25, 7, dtype=torch.float32, device=device) + + inference_delay = 4 + execution_horizon = 10 + + with torch.no_grad(): + # Use same noise for fair comparison + noise = policy.model.sample_noise((1, config.chunk_size, 7), device) + + # Test with RTC + actions_with_rtc = policy.predict_action_chunk( + batch, + noise=noise.clone(), + prev_chunk_left_over=prev_chunk, + inference_delay=inference_delay, + execution_horizon=execution_horizon, + ) + + # Test without RTC + policy.config.rtc_config.enabled = False + actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone()) + policy.config.rtc_config.enabled = True + + # Validate RTC behavior rules + all_passed, failures = validate_rtc_behavior( + rtc_actions=actions_with_rtc, + no_rtc_actions=actions_without_rtc, + prev_chunk=prev_chunk, + inference_delay=inference_delay, + execution_horizon=execution_horizon, + ) + + if not all_passed: + error_msg = "RTC validation failed:\n" + "\n".join(failures) + pytest.fail(error_msg) + + print("✓ SmolVLA RTC validation rules: All rules passed") + print(" ✓ Delay region [0:4]: RTC = prev_chunk") + print(" ✓ Blend region [4:10]: prev_chunk ≤ RTC ≤ no_rtc") + print(" ✓ Post-horizon [10:]: RTC = no_rtc") + + +@require_cuda +@pytest.mark.skipif(True, reason="Requires pretrained SmolVLA model weights") +def test_smolvla_rtc_different_schedules(): + """Test SmolVLA with different RTC attention schedules.""" + set_seed(42) + + schedules = [ + RTCAttentionSchedule.ZEROS, + RTCAttentionSchedule.ONES, + RTCAttentionSchedule.LINEAR, + RTCAttentionSchedule.EXP, + ] + + config = SmolVLAConfig(max_action_dim=7, chunk_size=50) + + config.input_features = { + "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)), + "observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + } + config.output_features = { + "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)), + } + + # Create dataset stats + 
dataset_stats = { + "observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)}, + "action": {"mean": torch.zeros(7), "std": torch.ones(7)}, + "observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)}, + } + + device = config.device + + for schedule in schedules: + print(f"Testing schedule: {schedule}") + + # Add RTC config with specific schedule + config.rtc_config = RTCConfig( + enabled=True, + execution_horizon=10, + max_guidance_weight=5.0, + prefix_attention_schedule=schedule, + debug=False, + ) + + # Instantiate policy + policy = SmolVLAPolicy(config) + policy.eval() + preprocessor, _ = make_pre_post_processors( + policy_cfg=config, pretrained_path=None, dataset_stats=dataset_stats + ) + + # Create dummy batch + batch = { + "observation.state": torch.randn(1, 14, dtype=torch.float32, device=device), + "observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device), + "task": ["Pick up the object"], + } + batch = preprocessor(batch) + + # Create previous chunk + prev_chunk = torch.randn(1, 25, 7, dtype=torch.float32, device=device) + + with torch.no_grad(): + noise = policy.model.sample_noise((1, config.chunk_size, 7), device) + actions = policy.predict_action_chunk( + batch, + noise=noise, + prev_chunk_left_over=prev_chunk, + inference_delay=4, + execution_horizon=10, + ) + + # Verify shape + assert actions.shape == (1, config.chunk_size, 7) + print(f" ✓ Schedule {schedule}: Test passed") + + print("✓ SmolVLA RTC different schedules: All schedules tested")