Add test to instantiate all base models

This commit is contained in:
Pepijn
2025-09-16 13:31:29 +02:00
parent 6aaeb7c13f
commit 0e0d6fbfc2
2 changed files with 146 additions and 6 deletions
+3 -6
View File
@@ -51,7 +51,7 @@ def test_policy_instantiation():
print(f"✓ Forward pass successful. Loss: {loss_dict['loss']:.4f}") print(f"✓ Forward pass successful. Loss: {loss_dict['loss']:.4f}")
except Exception as e: except Exception as e:
print(f"✗ Forward pass failed: {e}") print(f"✗ Forward pass failed: {e}")
return False raise
print("\nTesting action prediction...") print("\nTesting action prediction...")
try: try:
@@ -60,9 +60,7 @@ def test_policy_instantiation():
print(f"✓ Action prediction successful. Action shape: {action.shape}") print(f"✓ Action prediction successful. Action shape: {action.shape}")
except Exception as e: except Exception as e:
print(f"✗ Action prediction failed: {e}") print(f"✗ Action prediction failed: {e}")
return False raise
return True
@require_nightly_gpu @require_nightly_gpu
@@ -80,7 +78,6 @@ def test_config_creation():
print(f" Config type: {type(config).__name__}") print(f" Config type: {type(config).__name__}")
print(f" PaliGemma variant: {config.paligemma_variant}") print(f" PaliGemma variant: {config.paligemma_variant}")
print(f" Action expert variant: {config.action_expert_variant}") print(f" Action expert variant: {config.action_expert_variant}")
return True
except Exception as e: except Exception as e:
print(f"✗ Config creation failed: {e}") print(f"✗ Config creation failed: {e}")
return False raise
+143
View File
@@ -2,6 +2,7 @@
"""Test script to load PI0OpenPI model from HuggingFace hub and run inference.""" """Test script to load PI0OpenPI model from HuggingFace hub and run inference."""
import pytest
import torch import torch
from lerobot.policies.pi0_openpi import PI0OpenPIPolicy from lerobot.policies.pi0_openpi import PI0OpenPIPolicy
@@ -234,3 +235,145 @@ def _test_hub_loading(model_id, model_name):
print("\n" + "=" * 60) print("\n" + "=" * 60)
print(f"✓ All tests passed for {model_name}!") print(f"✓ All tests passed for {model_name}!")
print("=" * 60) print("=" * 60)
# Hub checkpoints for the six base models, grouped by model family.
_PI0_MODEL_IDS = (
    "pepijn223/pi0_base_fp32",
    "pepijn223/pi0_droid_fp32",
    "pepijn223/pi0_libero_fp32",
)
_PI05_MODEL_IDS = (
    "pepijn223/pi05_base_fp32",
    "pepijn223/pi05_droid_fp32",
    "pepijn223/pi05_libero_fp32",
)

# One (model_id, model_type, policy_class) triple per checkpoint; PI0 entries
# come first, then PI0.5, so the parametrized test order stays stable.
MODEL_TEST_PARAMS = [
    *((model_id, "PI0", PI0OpenPIPolicy) for model_id in _PI0_MODEL_IDS),
    *((model_id, "PI0.5", PI05OpenPIPolicy) for model_id in _PI05_MODEL_IDS),
]
@require_nightly_gpu
@pytest.mark.parametrize("model_id,model_type,policy_class", MODEL_TEST_PARAMS)
def test_all_base_models_hub_loading(model_id, model_type, policy_class):
    """Load one base checkpoint from the hub and smoke-test it end to end.

    The test loads the policy with ``strict=True``, checks which submodules
    exist on ``policy.model`` for the given model family, swaps in
    normalization layers built from dummy statistics, and then runs one
    training forward pass and one ``select_action`` call on a random batch.

    Args:
        model_id: HuggingFace model ID (e.g., "pepijn223/pi0_base_fp32").
        model_type: Model family label; "PI0.5" selects the PI0.5 checks and
            any other value selects the PI0 checks.
        policy_class: Policy class whose ``from_pretrained`` is used to load
            the checkpoint (PI0OpenPIPolicy or PI05OpenPIPolicy).
    """
    banner = "=" * 80
    print(f"\n{banner}")
    print(f"Testing {model_type} model: {model_id}")
    print(f"{banner}")

    # Load the checkpoint; report the failure before letting it propagate.
    try:
        policy = policy_class.from_pretrained(model_id, strict=True)
        print(f"✓ Successfully loaded {model_type} model from {model_id}")
    except Exception as err:
        print(f"✗ Failed to load model {model_id}: {err}")
        raise

    # Summarize the loaded model.
    cfg = policy.config
    device = next(policy.parameters()).device
    print("\nModel configuration:")
    print(f" - Model ID: {model_id}")
    print(f" - Model type: {model_type}")
    print(f" - PaliGemma variant: {cfg.paligemma_variant}")
    print(f" - Action expert variant: {cfg.action_expert_variant}")
    print(f" - Action dimension: {cfg.action_dim}")
    print(f" - State dimension: {cfg.state_dim}")
    print(f" - Chunk size: {cfg.chunk_size}")
    print(f" - Tokenizer max length: {cfg.tokenizer_max_length}")
    print(f" - Device: {device}")
    print(f" - Dtype: {next(policy.parameters()).dtype}")

    # Family-specific architecture checks.
    if model_type != "PI0.5":
        # PI0: separate state projection and fused action/time MLP, no AdaRMS.
        model = policy.model
        assert hasattr(model, "action_time_mlp_in"), f"{model_id}: PI0 should have action_time_mlp_in"
        assert hasattr(model, "action_time_mlp_out"), f"{model_id}: PI0 should have action_time_mlp_out"
        assert hasattr(model, "state_proj"), f"{model_id}: PI0 should have state_proj"
        assert not hasattr(model, "time_mlp_in"), f"{model_id}: PI0 should not have time_mlp_in"
        uses_adarms = model.paligemma_with_expert.gemma_expert.config.use_adarms
        assert uses_adarms == False, f"{model_id}: PI0 expert should not use AdaRMS"  # noqa: E712
        print(" ✓ PI0 architecture verified")
    else:
        # PI0.5: dedicated time MLP, no state projection, AdaRMS enabled.
        print(f" - discrete_state_input: {cfg.discrete_state_input}")
        model = policy.model
        assert hasattr(model, "time_mlp_in"), f"{model_id}: PI0.5 should have time_mlp_in"
        assert hasattr(model, "time_mlp_out"), f"{model_id}: PI0.5 should have time_mlp_out"
        assert not hasattr(model, "state_proj"), f"{model_id}: PI0.5 should not have state_proj"
        assert not hasattr(model, "action_time_mlp_in"), (
            f"{model_id}: PI0.5 should not have action_time_mlp_in"
        )
        uses_adarms = model.paligemma_with_expert.gemma_expert.config.use_adarms
        assert uses_adarms == True, f"{model_id}: PI0.5 expert should use AdaRMS"  # noqa: E712
        print(" ✓ PI0.5 architecture verified")

    # Dummy statistics, moved onto the model's device (entries replaced in place).
    dummy_stats = create_dummy_stats(cfg)
    for feature_name in list(dummy_stats):
        feature_stats = dummy_stats[feature_name]
        dummy_stats[feature_name] = {
            stat_name: feature_stats[stat_name].to(device) for stat_name in ("mean", "std")
        }

    # Replace the policy's normalization layers with ones built from the dummy stats.
    from lerobot.policies.normalize import Normalize, Unnormalize

    norm_mapping = cfg.normalization_mapping
    policy.normalize_inputs = Normalize(cfg.input_features, norm_mapping, dummy_stats)
    policy.normalize_targets = Normalize(cfg.output_features, norm_mapping, dummy_stats)
    policy.unnormalize_outputs = Unnormalize(cfg.output_features, norm_mapping, dummy_stats)

    # Random single-sample batch: state, action chunk, task prompt, then images.
    batch_size = 1
    tensor_opts = {"dtype": torch.float32, "device": device}
    batch = {
        "observation.state": torch.randn(batch_size, cfg.state_dim, **tensor_opts),
        "action": torch.randn(batch_size, cfg.chunk_size, cfg.action_dim, **tensor_opts),
        "task": ["Pick up the object"] * batch_size,
    }
    for image_key in cfg.image_keys:
        batch[image_key] = torch.rand(batch_size, 3, 224, 224, **tensor_opts)

    # Training-mode forward pass: loss must be finite and non-negative.
    print(f"\nTesting forward pass for {model_id}...")
    try:
        policy.train()
        loss, loss_dict = policy.forward(batch)
        assert not torch.isnan(loss), f"{model_id}: Forward pass produced NaN loss"
        assert loss.item() >= 0, f"{model_id}: Loss should be non-negative"
        print(f"✓ Forward pass successful - Loss: {loss_dict['loss']:.4f}")
    except Exception as err:
        print(f"✗ Forward pass failed for {model_id}: {err}")
        raise

    # Eval-mode action selection: one action per batch element, no NaNs.
    print(f"Testing action prediction for {model_id}...")
    try:
        policy.eval()
        with torch.no_grad():
            action = policy.select_action(batch)
        expected_shape = (batch_size, cfg.action_dim)
        assert action.shape == expected_shape, (
            f"{model_id}: Expected action shape {expected_shape}, got {action.shape}"
        )
        assert not torch.isnan(action).any(), f"{model_id}: Action contains NaN values"
        print(f"✓ Action prediction successful - Shape: {action.shape}")
    except Exception as err:
        print(f"✗ Action prediction failed for {model_id}: {err}")
        raise

    print(f"✅ All tests passed for {model_id}!")