From 0e0d6fbfc2177956a5385c135e11222a31a4f915 Mon Sep 17 00:00:00 2001
From: Pepijn
Date: Tue, 16 Sep 2025 13:31:29 +0200
Subject: [PATCH] Add test to instantiate all base models

---
 tests/policies/test_pi0_openpi.py   |   9 +-
 tests/policies/test_pi0_pi05_hub.py | 143 ++++++++++++++++++++++++++++
 2 files changed, 146 insertions(+), 6 deletions(-)

diff --git a/tests/policies/test_pi0_openpi.py b/tests/policies/test_pi0_openpi.py
index 19f0c4346..36c8200b3 100644
--- a/tests/policies/test_pi0_openpi.py
+++ b/tests/policies/test_pi0_openpi.py
@@ -51,7 +51,7 @@ def test_policy_instantiation():
         print(f"✓ Forward pass successful. Loss: {loss_dict['loss']:.4f}")
     except Exception as e:
         print(f"✗ Forward pass failed: {e}")
-        return False
+        raise
 
     print("\nTesting action prediction...")
     try:
@@ -60,9 +60,7 @@ def test_policy_instantiation():
         print(f"✓ Action prediction successful. Action shape: {action.shape}")
     except Exception as e:
         print(f"✗ Action prediction failed: {e}")
-        return False
-
-    return True
+        raise
 
 
 @require_nightly_gpu
@@ -80,7 +78,6 @@ def test_config_creation():
         print(f" Config type: {type(config).__name__}")
         print(f" PaliGemma variant: {config.paligemma_variant}")
         print(f" Action expert variant: {config.action_expert_variant}")
-        return True
     except Exception as e:
         print(f"✗ Config creation failed: {e}")
-        return False
+        raise
diff --git a/tests/policies/test_pi0_pi05_hub.py b/tests/policies/test_pi0_pi05_hub.py
index d881f89f3..759caf13b 100644
--- a/tests/policies/test_pi0_pi05_hub.py
+++ b/tests/policies/test_pi0_pi05_hub.py
@@ -2,6 +2,7 @@
 
 """Test script to load PI0OpenPI model from HuggingFace hub and run inference."""
 
+import pytest
 import torch
 
 from lerobot.policies.pi0_openpi import PI0OpenPIPolicy
@@ -234,3 +235,145 @@ def _test_hub_loading(model_id, model_name):
     print("\n" + "=" * 60)
     print(f"✓ All tests passed for {model_name}!")
     print("=" * 60)
+
+
+# Test data for all 6 base models
+MODEL_TEST_PARAMS = [
+    # PI0 models
+    ("pepijn223/pi0_base_fp32", "PI0", PI0OpenPIPolicy),
+    ("pepijn223/pi0_droid_fp32", "PI0", PI0OpenPIPolicy),
+    ("pepijn223/pi0_libero_fp32", "PI0", PI0OpenPIPolicy),
+    # PI0.5 models
+    ("pepijn223/pi05_base_fp32", "PI0.5", PI05OpenPIPolicy),
+    ("pepijn223/pi05_droid_fp32", "PI0.5", PI05OpenPIPolicy),
+    ("pepijn223/pi05_libero_fp32", "PI0.5", PI05OpenPIPolicy),
+]
+
+
+@require_nightly_gpu
+@pytest.mark.parametrize("model_id,model_type,policy_class", MODEL_TEST_PARAMS)
+def test_all_base_models_hub_loading(model_id, model_type, policy_class):
+    """Test loading and basic functionality of all 6 base models from HuggingFace Hub.
+
+    Args:
+        model_id: HuggingFace model ID (e.g., "pepijn223/pi0_base_fp32")
+        model_type: Model type ("PI0" or "PI0.5")
+        policy_class: Policy class to use (PI0OpenPIPolicy or PI05OpenPIPolicy)
+    """
+    print(f"\n{'=' * 80}")
+    print(f"Testing {model_type} model: {model_id}")
+    print(f"{'=' * 80}")
+
+    # Load the model from HuggingFace hub
+    try:
+        policy = policy_class.from_pretrained(model_id, strict=True)
+        print(f"✓ Successfully loaded {model_type} model from {model_id}")
+    except Exception as e:
+        print(f"✗ Failed to load model {model_id}: {e}")
+        raise
+
+    # Get model info
+    device = next(policy.parameters()).device
+    print("\nModel configuration:")
+    print(f" - Model ID: {model_id}")
+    print(f" - Model type: {model_type}")
+    print(f" - PaliGemma variant: {policy.config.paligemma_variant}")
+    print(f" - Action expert variant: {policy.config.action_expert_variant}")
+    print(f" - Action dimension: {policy.config.action_dim}")
+    print(f" - State dimension: {policy.config.state_dim}")
+    print(f" - Chunk size: {policy.config.chunk_size}")
+    print(f" - Tokenizer max length: {policy.config.tokenizer_max_length}")
+    print(f" - Device: {device}")
+    print(f" - Dtype: {next(policy.parameters()).dtype}")
+
+    # Verify model-specific architecture
+    if model_type == "PI0.5":
+        print(f" - discrete_state_input: {policy.config.discrete_state_input}")
+        # Verify PI0.5 specific features
+        assert hasattr(policy.model, "time_mlp_in"), f"{model_id}: PI0.5 should have time_mlp_in"
+        assert hasattr(policy.model, "time_mlp_out"), f"{model_id}: PI0.5 should have time_mlp_out"
+        assert not hasattr(policy.model, "state_proj"), f"{model_id}: PI0.5 should not have state_proj"
+        assert not hasattr(policy.model, "action_time_mlp_in"), (
+            f"{model_id}: PI0.5 should not have action_time_mlp_in"
+        )
+        adarms_expert_config = policy.model.paligemma_with_expert.gemma_expert.config.use_adarms
+        assert adarms_expert_config == True, f"{model_id}: PI0.5 expert should use AdaRMS"  # noqa: E712
+        print(" ✓ PI0.5 architecture verified")
+    else:
+        # Verify PI0 specific features
+        assert hasattr(policy.model, "action_time_mlp_in"), f"{model_id}: PI0 should have action_time_mlp_in"
+        assert hasattr(policy.model, "action_time_mlp_out"), (
+            f"{model_id}: PI0 should have action_time_mlp_out"
+        )
+        assert hasattr(policy.model, "state_proj"), f"{model_id}: PI0 should have state_proj"
+        assert not hasattr(policy.model, "time_mlp_in"), f"{model_id}: PI0 should not have time_mlp_in"
+        adarms_expert_config = policy.model.paligemma_with_expert.gemma_expert.config.use_adarms
+        assert adarms_expert_config == False, f"{model_id}: PI0 expert should not use AdaRMS"  # noqa: E712
+        print(" ✓ PI0 architecture verified")
+
+    # Create dummy stats for testing
+    dummy_stats = create_dummy_stats(policy.config)
+    for key, stats in dummy_stats.items():
+        dummy_stats[key] = {
+            "mean": stats["mean"].to(device),
+            "std": stats["std"].to(device),
+        }
+
+    # Initialize normalization layers with dummy stats
+    from lerobot.policies.normalize import Normalize, Unnormalize
+
+    policy.normalize_inputs = Normalize(
+        policy.config.input_features, policy.config.normalization_mapping, dummy_stats
+    )
+    policy.normalize_targets = Normalize(
+        policy.config.output_features, policy.config.normalization_mapping, dummy_stats
+    )
+    policy.unnormalize_outputs = Unnormalize(
+        policy.config.output_features, policy.config.normalization_mapping, dummy_stats
+    )
+
+    # Create test batch
+    batch_size = 1
+    batch = {
+        "observation.state": torch.randn(
+            batch_size, policy.config.state_dim, dtype=torch.float32, device=device
+        ),
+        "action": torch.randn(
+            batch_size, policy.config.chunk_size, policy.config.action_dim, dtype=torch.float32, device=device
+        ),
+        "task": ["Pick up the object"] * batch_size,
+    }
+
+    # Add images based on config
+    for key in policy.config.image_keys:
+        batch[key] = torch.rand(batch_size, 3, 224, 224, dtype=torch.float32, device=device)
+
+    # Test forward pass
+    print(f"\nTesting forward pass for {model_id}...")
+    try:
+        policy.train()
+        loss, loss_dict = policy.forward(batch)
+        assert not torch.isnan(loss), f"{model_id}: Forward pass produced NaN loss"
+        assert loss.item() >= 0, f"{model_id}: Loss should be non-negative"
+        print(f"✓ Forward pass successful - Loss: {loss_dict['loss']:.4f}")
+    except Exception as e:
+        print(f"✗ Forward pass failed for {model_id}: {e}")
+        raise
+
+    # Test action prediction
+    print(f"Testing action prediction for {model_id}...")
+    try:
+        policy.eval()
+        with torch.no_grad():
+            action = policy.select_action(batch)
+        expected_shape = (batch_size, policy.config.action_dim)
+        assert action.shape == expected_shape, (
+            f"{model_id}: Expected action shape {expected_shape}, got {action.shape}"
+        )
+        assert not torch.isnan(action).any(), f"{model_id}: Action contains NaN values"
+        print(f"✓ Action prediction successful - Shape: {action.shape}")
+    except Exception as e:
+        print(f"✗ Action prediction failed for {model_id}: {e}")
+        raise
+
+    print(f"✅ All tests passed for {model_id}!")
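
For a quick manual check outside pytest, a minimal sketch along the lines of the new parametrized test (this is an illustration, not part of the patch; it assumes a GPU environment with the nightly dependencies installed, and uses one of the checkpoint IDs listed in MODEL_TEST_PARAMS):

    from lerobot.policies.pi0_openpi import PI0OpenPIPolicy

    # Load one of the fp32 base checkpoints exercised by the new test.
    policy = PI0OpenPIPolicy.from_pretrained("pepijn223/pi0_base_fp32", strict=True)
    policy.eval()

    # Spot-check what the test asserts: fp32 weights and a PI0-style architecture.
    print(next(policy.parameters()).dtype)               # expected: torch.float32
    print(hasattr(policy.model, "action_time_mlp_in"))   # expected: True for PI0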