#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test PI0 policy with Real-Time Chunking (RTC) enabled during inference."""

import os

import pytest
import torch

# Skip every test in this module in CI. Note that pytestmark only skips tests:
# module-level statements still execute at collection time, so all heavy work
# (model construction, inference) must live inside test functions.
pytestmark = pytest.mark.skipif(
    os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
    reason="This test requires local OpenPI installation and is not meant for CI",
)

from lerobot.configs.types import FeatureType, PolicyFeature, RTCAttentionSchedule  # noqa: E402
from lerobot.policies.pi0 import PI0Config, PI0Policy, make_pi0_pre_post_processors  # noqa: E402
from lerobot.policies.rtc.configuration_rtc import RTCConfig  # noqa: E402
from lerobot.utils.random_utils import set_seed  # noqa: E402
from tests.utils import require_cuda  # noqa: E402


def _drop_batch_dim(actions: torch.Tensor) -> torch.Tensor:
    """Return ``actions`` on CPU, squeezing dim 0 when the tensor is 3-D."""
    # squeeze(0) only removes dim 0 when it has size 1; otherwise the tensor is unchanged.
    return actions.squeeze(0).cpu() if actions.dim() == 3 else actions.cpu()


def validate_rtc_behavior(
    rtc_actions: torch.Tensor,
    no_rtc_actions: torch.Tensor,
    prev_chunk: torch.Tensor,
    inference_delay: int,
    execution_horizon: int,
    rtol: float = 1e-2,
) -> tuple[bool, list[str]]:
    """Validate RTC behavior follows expected rules.

    After dropping a leading batch dimension, the three chunks are compared over
    ``chunk_len``, the smallest of their first-dimension sizes:

    1. Delay region ``[0:inference_delay]``: RTC actions must be allclose to ``prev_chunk``.
    2. Blend region ``[inference_delay:execution_horizon]``: every RTC element must lie
       (inclusively) between the matching ``prev_chunk`` and ``no_rtc_actions`` elements.
    3. Post-horizon ``[execution_horizon:chunk_len]``: RTC actions must be allclose to
       ``no_rtc_actions``.

    Args:
        rtc_actions: Actions predicted with RTC guidance (3-D inputs have dim 0 squeezed).
        no_rtc_actions: Actions predicted from the same noise with RTC disabled.
        prev_chunk: Leftover actions from the previous chunk.
        inference_delay: Number of leading steps checked against ``prev_chunk``.
        execution_horizon: Exclusive end of the blend region.
        rtol: Relative tolerance for the ``torch.allclose`` checks in rules 1 and 3.

    Returns:
        Tuple of (all_passed, failures) where failures is a list of error messages.
    """
    rtc_actions_t = _drop_batch_dim(rtc_actions)
    no_rtc_actions_t = _drop_batch_dim(no_rtc_actions)
    prev_chunk_t = _drop_batch_dim(prev_chunk)

    chunk_len = min(rtc_actions_t.shape[0], no_rtc_actions_t.shape[0], prev_chunk_t.shape[0])
    failures = []

    # Rule 1: Delay region [0:inference_delay] - RTC should equal prev_chunk
    if inference_delay > 0:
        delay_end = min(inference_delay, chunk_len)
        rtc_delay = rtc_actions_t[:delay_end]
        prev_delay = prev_chunk_t[:delay_end]

        if not torch.allclose(rtc_delay, prev_delay, rtol=rtol):
            max_diff = torch.max(torch.abs(rtc_delay - prev_delay)).item()
            failures.append(
                f"Delay region [0:{delay_end}]: RTC does NOT equal prev_chunk (max diff: {max_diff:.6f})"
            )

    # Rule 2: Blend region [inference_delay:execution_horizon] - RTC lies between the two references
    blend_start = inference_delay
    blend_end = min(execution_horizon, chunk_len)
    if blend_end > blend_start:
        rtc_blend = rtc_actions_t[blend_start:blend_end]
        prev_blend = prev_chunk_t[blend_start:blend_end]
        no_rtc_blend = no_rtc_actions_t[blend_start:blend_end]

        # Either reference may be the larger one element-wise, so bound by min/max.
        min_bound = torch.minimum(prev_blend, no_rtc_blend)
        max_bound = torch.maximum(prev_blend, no_rtc_blend)
        within_bounds = torch.logical_and(rtc_blend >= min_bound, rtc_blend <= max_bound)

        if not torch.all(within_bounds):
            violations = torch.sum(~within_bounds).item()
            total_elements = within_bounds.numel()
            failures.append(
                f"Blend region [{blend_start}:{blend_end}]: "
                f"RTC is NOT between prev_chunk and no_rtc ({violations}/{total_elements} violations)"
            )

    # Rule 3: Post-horizon [execution_horizon:] - RTC should equal no_rtc
    if execution_horizon < chunk_len:
        rtc_after = rtc_actions_t[execution_horizon:chunk_len]
        no_rtc_after = no_rtc_actions_t[execution_horizon:chunk_len]

        if not torch.allclose(rtc_after, no_rtc_after, rtol=rtol):
            max_diff = torch.max(torch.abs(rtc_after - no_rtc_after)).item()
            failures.append(
                f"Post-horizon [{execution_horizon}:{chunk_len}]: "
                f"RTC does NOT equal no_rtc (max diff: {max_diff:.6f})"
            )

    return len(failures) == 0, failures


def _make_rtc_config(schedule: RTCAttentionSchedule = RTCAttentionSchedule.EXP) -> RTCConfig:
    """Return the enabled RTC config shared by these tests (horizon 10, max guidance 5.0)."""
    return RTCConfig(
        enabled=True,
        execution_horizon=10,
        max_guidance_weight=5.0,
        prefix_attention_schedule=schedule,
        debug=False,
    )


def _set_features(config: PI0Config) -> None:
    """Attach a 14-dim state, one (3, 224, 224) camera image and a 7-dim action to ``config``."""
    config.input_features = {
        "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
        "observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
    }
    config.output_features = {
        "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
    }


def _make_dataset_stats() -> dict[str, dict[str, torch.Tensor]]:
    """Return mean-0 / std-1 normalization stats for every feature set by ``_set_features``."""
    return {
        "observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
        "action": {"mean": torch.zeros(7), "std": torch.ones(7)},
        "observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
    }


def _make_batch(preprocessor, device):
    """Build a random single-sample observation batch on ``device`` and preprocess it."""
    batch = {
        "observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
        "observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
        "task": ["Pick up the object"],
    }
    return preprocessor(batch)


def _predict_without_rtc(policy, batch, noise: torch.Tensor) -> torch.Tensor:
    """Predict an action chunk from a copy of ``noise`` with RTC disabled.

    The RTC flag is re-enabled in ``finally`` so a failing prediction cannot
    leave the policy in the disabled state.
    """
    policy.config.rtc_config.enabled = False
    try:
        return policy.predict_action_chunk(batch, noise=noise.clone())
    finally:
        policy.config.rtc_config.enabled = True


@require_cuda
def test_pi0_rtc_initialization():
    """Test PI0 policy can initialize RTC processor."""
    set_seed(42)

    config = PI0Config(max_action_dim=7, max_state_dim=14, dtype="float32")
    config.rtc_config = _make_rtc_config()
    _set_features(config)

    policy = PI0Policy(config)

    # Verify RTC processor is initialized
    assert hasattr(policy, "rtc_processor")
    assert policy.rtc_processor is not None
    assert policy.rtc_processor.rtc_config.enabled is True

    print("✓ PI0 RTC initialization: Test passed")


@require_cuda
def test_pi0_rtc_initialization_without_rtc_config():
    """Test PI0 policy can initialize without RTC config."""
    set_seed(42)

    config = PI0Config(max_action_dim=7, max_state_dim=14, dtype="float32")

    policy = PI0Policy(config)

    # Verify RTC processor is not initialized
    assert hasattr(policy, "rtc_processor")
    assert policy.rtc_processor is None
    assert policy.model.rtc_processor is None
    assert policy._rtc_enabled() is False

    print("✓ PI0 RTC initialization without RTC config: Test passed")


@require_cuda
def test_pi0_rtc_inference_with_prev_chunk():
    """Test PI0 policy inference with RTC and previous chunk."""
    set_seed(42)

    config = PI0Config(max_action_dim=7, max_state_dim=14, chunk_size=50, dtype="float32")
    config.rtc_config = _make_rtc_config()
    _set_features(config)

    policy = PI0Policy(config)
    policy.eval()
    preprocessor, _ = make_pi0_pre_post_processors(config=config, dataset_stats=_make_dataset_stats())

    device = config.device
    batch = _make_batch(preprocessor, device)

    # Leftover actions from a previous chunk (25 of the 50 steps)
    prev_chunk = torch.randn(1, 25, 7, dtype=torch.float32, device=device)

    with torch.no_grad():
        # Use same noise for fair comparison
        noise = policy.model.sample_noise((1, config.chunk_size, 7), device)

        actions_with_rtc = policy.predict_action_chunk(
            batch,
            noise=noise.clone(),
            prev_chunk_left_over=prev_chunk,
            inference_delay=4,
            execution_horizon=10,
        )
        actions_without_rtc = _predict_without_rtc(policy, batch, noise)

    assert actions_with_rtc.shape == (1, config.chunk_size, 7)
    assert actions_without_rtc.shape == (1, config.chunk_size, 7)

    # With previous chunk, actions should be different (RTC guidance applied)
    assert not torch.allclose(actions_with_rtc, actions_without_rtc, rtol=1e-3)

    print("✓ PI0 RTC inference with prev_chunk: Test passed")


@require_cuda
def test_pi0_rtc_inference_without_prev_chunk():
    """Test PI0 policy inference with RTC but no previous chunk (RTC should have no effect)."""
    set_seed(42)

    config = PI0Config(max_action_dim=7, max_state_dim=14, chunk_size=50, dtype="float32")
    config.rtc_config = _make_rtc_config()
    _set_features(config)

    policy = PI0Policy(config)
    policy.eval()
    preprocessor, _ = make_pi0_pre_post_processors(config=config, dataset_stats=_make_dataset_stats())

    device = config.device
    batch = _make_batch(preprocessor, device)

    with torch.no_grad():
        # Use same noise for fair comparison
        noise = policy.model.sample_noise((1, config.chunk_size, 7), device)

        actions_with_rtc_no_prev = policy.predict_action_chunk(
            batch,
            noise=noise.clone(),
            prev_chunk_left_over=None,
        )
        actions_without_rtc = _predict_without_rtc(policy, batch, noise)

    # Without previous chunk, RTC should have no effect
    assert torch.allclose(actions_with_rtc_no_prev, actions_without_rtc, rtol=1e-5)

    print("✓ PI0 RTC inference without prev_chunk: Test passed")


@require_cuda
def test_pi0_rtc_validation_rules():
    """Test PI0 policy with RTC follows all three validation rules."""
    set_seed(42)

    config = PI0Config(max_action_dim=7, max_state_dim=14, chunk_size=50, dtype="float32")
    config.rtc_config = _make_rtc_config()
    _set_features(config)

    policy = PI0Policy(config)
    policy.eval()
    preprocessor, _ = make_pi0_pre_post_processors(config=config, dataset_stats=_make_dataset_stats())

    device = config.device
    batch = _make_batch(preprocessor, device)

    prev_chunk = torch.randn(1, 25, 7, dtype=torch.float32, device=device)

    inference_delay = 4
    execution_horizon = 10

    with torch.no_grad():
        # Use same noise for fair comparison
        noise = policy.model.sample_noise((1, config.chunk_size, 7), device)

        actions_with_rtc = policy.predict_action_chunk(
            batch,
            noise=noise.clone(),
            prev_chunk_left_over=prev_chunk,
            inference_delay=inference_delay,
            execution_horizon=execution_horizon,
        )
        actions_without_rtc = _predict_without_rtc(policy, batch, noise)

    all_passed, failures = validate_rtc_behavior(
        rtc_actions=actions_with_rtc,
        no_rtc_actions=actions_without_rtc,
        prev_chunk=prev_chunk,
        inference_delay=inference_delay,
        execution_horizon=execution_horizon,
    )

    if not all_passed:
        pytest.fail("RTC validation failed:\n" + "\n".join(failures))

    print("✓ PI0 RTC validation rules: All rules passed")
    print(f"  ✓ Delay region [0:{inference_delay}]: RTC = prev_chunk")
    print(f"  ✓ Blend region [{inference_delay}:{execution_horizon}]: RTC between prev_chunk and no_rtc")
    print(f"  ✓ Post-horizon [{execution_horizon}:]: RTC = no_rtc")


@require_cuda
def test_pi0_rtc_different_schedules():
    """Test PI0 with different RTC attention schedules."""
    set_seed(42)

    schedules = [
        RTCAttentionSchedule.ZEROS,
        RTCAttentionSchedule.ONES,
        RTCAttentionSchedule.LINEAR,
        RTCAttentionSchedule.EXP,
    ]

    config = PI0Config(max_action_dim=7, max_state_dim=14, chunk_size=50, dtype="float32")
    _set_features(config)
    dataset_stats = _make_dataset_stats()
    device = config.device

    for schedule in schedules:
        print(f"Testing schedule: {schedule}")

        # Same RTC settings as the other tests, varying only the prefix attention schedule
        config.rtc_config = _make_rtc_config(schedule)

        policy = PI0Policy(config)
        policy.eval()
        preprocessor, _ = make_pi0_pre_post_processors(config=config, dataset_stats=dataset_stats)

        batch = _make_batch(preprocessor, device)
        prev_chunk = torch.randn(1, 25, 7, dtype=torch.float32, device=device)

        with torch.no_grad():
            noise = policy.model.sample_noise((1, config.chunk_size, 7), device)
            actions = policy.predict_action_chunk(
                batch,
                noise=noise,
                prev_chunk_left_over=prev_chunk,
                inference_delay=4,
                execution_horizon=10,
            )

        assert actions.shape == (1, config.chunk_size, 7)
        print(f"  ✓ Schedule {schedule}: Test passed")

        # Release this policy before building the next one so two models are never alive at once.
        del policy

    print("✓ PI0 RTC different schedules: All schedules tested")