mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 20:19:43 +00:00
Add test to instantiate all base models
This commit is contained in:
@@ -51,7 +51,7 @@ def test_policy_instantiation():
|
|||||||
print(f"✓ Forward pass successful. Loss: {loss_dict['loss']:.4f}")
|
print(f"✓ Forward pass successful. Loss: {loss_dict['loss']:.4f}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"✗ Forward pass failed: {e}")
|
print(f"✗ Forward pass failed: {e}")
|
||||||
return False
|
raise
|
||||||
|
|
||||||
print("\nTesting action prediction...")
|
print("\nTesting action prediction...")
|
||||||
try:
|
try:
|
||||||
@@ -60,9 +60,7 @@ def test_policy_instantiation():
|
|||||||
print(f"✓ Action prediction successful. Action shape: {action.shape}")
|
print(f"✓ Action prediction successful. Action shape: {action.shape}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"✗ Action prediction failed: {e}")
|
print(f"✗ Action prediction failed: {e}")
|
||||||
return False
|
raise
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
@require_nightly_gpu
|
@require_nightly_gpu
|
||||||
@@ -80,7 +78,6 @@ def test_config_creation():
|
|||||||
print(f" Config type: {type(config).__name__}")
|
print(f" Config type: {type(config).__name__}")
|
||||||
print(f" PaliGemma variant: {config.paligemma_variant}")
|
print(f" PaliGemma variant: {config.paligemma_variant}")
|
||||||
print(f" Action expert variant: {config.action_expert_variant}")
|
print(f" Action expert variant: {config.action_expert_variant}")
|
||||||
return True
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"✗ Config creation failed: {e}")
|
print(f"✗ Config creation failed: {e}")
|
||||||
return False
|
raise
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
"""Test script to load PI0OpenPI model from HuggingFace hub and run inference."""
|
"""Test script to load PI0OpenPI model from HuggingFace hub and run inference."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.policies.pi0_openpi import PI0OpenPIPolicy
|
from lerobot.policies.pi0_openpi import PI0OpenPIPolicy
|
||||||
@@ -234,3 +235,145 @@ def _test_hub_loading(model_id, model_name):
|
|||||||
print("\n" + "=" * 60)
|
print("\n" + "=" * 60)
|
||||||
print(f"✓ All tests passed for {model_name}!")
|
print(f"✓ All tests passed for {model_name}!")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
||||||
|
|
||||||
|
# Test data for all 6 base models
|
||||||
|
MODEL_TEST_PARAMS = [
|
||||||
|
# PI0 models
|
||||||
|
("pepijn223/pi0_base_fp32", "PI0", PI0OpenPIPolicy),
|
||||||
|
("pepijn223/pi0_droid_fp32", "PI0", PI0OpenPIPolicy),
|
||||||
|
("pepijn223/pi0_libero_fp32", "PI0", PI0OpenPIPolicy),
|
||||||
|
# PI0.5 models
|
||||||
|
("pepijn223/pi05_base_fp32", "PI0.5", PI05OpenPIPolicy),
|
||||||
|
("pepijn223/pi05_droid_fp32", "PI0.5", PI05OpenPIPolicy),
|
||||||
|
("pepijn223/pi05_libero_fp32", "PI0.5", PI05OpenPIPolicy),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@require_nightly_gpu
|
||||||
|
@pytest.mark.parametrize("model_id,model_type,policy_class", MODEL_TEST_PARAMS)
|
||||||
|
def test_all_base_models_hub_loading(model_id, model_type, policy_class):
|
||||||
|
"""Test loading and basic functionality of all 6 base models from HuggingFace Hub.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_id: HuggingFace model ID (e.g., "pepijn223/pi0_base_fp32")
|
||||||
|
model_type: Model type ("PI0" or "PI0.5")
|
||||||
|
policy_class: Policy class to use (PI0OpenPIPolicy or PI05OpenPIPolicy)
|
||||||
|
"""
|
||||||
|
print(f"\n{'=' * 80}")
|
||||||
|
print(f"Testing {model_type} model: {model_id}")
|
||||||
|
print(f"{'=' * 80}")
|
||||||
|
|
||||||
|
# Load the model from HuggingFace hub
|
||||||
|
try:
|
||||||
|
policy = policy_class.from_pretrained(model_id, strict=True)
|
||||||
|
print(f"✓ Successfully loaded {model_type} model from {model_id}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"✗ Failed to load model {model_id}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Get model info
|
||||||
|
device = next(policy.parameters()).device
|
||||||
|
print("\nModel configuration:")
|
||||||
|
print(f" - Model ID: {model_id}")
|
||||||
|
print(f" - Model type: {model_type}")
|
||||||
|
print(f" - PaliGemma variant: {policy.config.paligemma_variant}")
|
||||||
|
print(f" - Action expert variant: {policy.config.action_expert_variant}")
|
||||||
|
print(f" - Action dimension: {policy.config.action_dim}")
|
||||||
|
print(f" - State dimension: {policy.config.state_dim}")
|
||||||
|
print(f" - Chunk size: {policy.config.chunk_size}")
|
||||||
|
print(f" - Tokenizer max length: {policy.config.tokenizer_max_length}")
|
||||||
|
print(f" - Device: {device}")
|
||||||
|
print(f" - Dtype: {next(policy.parameters()).dtype}")
|
||||||
|
|
||||||
|
# Verify model-specific architecture
|
||||||
|
if model_type == "PI0.5":
|
||||||
|
print(f" - discrete_state_input: {policy.config.discrete_state_input}")
|
||||||
|
# Verify PI0.5 specific features
|
||||||
|
assert hasattr(policy.model, "time_mlp_in"), f"{model_id}: PI0.5 should have time_mlp_in"
|
||||||
|
assert hasattr(policy.model, "time_mlp_out"), f"{model_id}: PI0.5 should have time_mlp_out"
|
||||||
|
assert not hasattr(policy.model, "state_proj"), f"{model_id}: PI0.5 should not have state_proj"
|
||||||
|
assert not hasattr(policy.model, "action_time_mlp_in"), (
|
||||||
|
f"{model_id}: PI0.5 should not have action_time_mlp_in"
|
||||||
|
)
|
||||||
|
adarms_expert_config = policy.model.paligemma_with_expert.gemma_expert.config.use_adarms
|
||||||
|
assert adarms_expert_config == True, f"{model_id}: PI0.5 expert should use AdaRMS" # noqa: E712
|
||||||
|
print(" ✓ PI0.5 architecture verified")
|
||||||
|
else:
|
||||||
|
# Verify PI0 specific features
|
||||||
|
assert hasattr(policy.model, "action_time_mlp_in"), f"{model_id}: PI0 should have action_time_mlp_in"
|
||||||
|
assert hasattr(policy.model, "action_time_mlp_out"), (
|
||||||
|
f"{model_id}: PI0 should have action_time_mlp_out"
|
||||||
|
)
|
||||||
|
assert hasattr(policy.model, "state_proj"), f"{model_id}: PI0 should have state_proj"
|
||||||
|
assert not hasattr(policy.model, "time_mlp_in"), f"{model_id}: PI0 should not have time_mlp_in"
|
||||||
|
adarms_expert_config = policy.model.paligemma_with_expert.gemma_expert.config.use_adarms
|
||||||
|
assert adarms_expert_config == False, f"{model_id}: PI0 expert should not use AdaRMS" # noqa: E712
|
||||||
|
print(" ✓ PI0 architecture verified")
|
||||||
|
|
||||||
|
# Create dummy stats for testing
|
||||||
|
dummy_stats = create_dummy_stats(policy.config)
|
||||||
|
for key, stats in dummy_stats.items():
|
||||||
|
dummy_stats[key] = {
|
||||||
|
"mean": stats["mean"].to(device),
|
||||||
|
"std": stats["std"].to(device),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Initialize normalization layers with dummy stats
|
||||||
|
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||||
|
|
||||||
|
policy.normalize_inputs = Normalize(
|
||||||
|
policy.config.input_features, policy.config.normalization_mapping, dummy_stats
|
||||||
|
)
|
||||||
|
policy.normalize_targets = Normalize(
|
||||||
|
policy.config.output_features, policy.config.normalization_mapping, dummy_stats
|
||||||
|
)
|
||||||
|
policy.unnormalize_outputs = Unnormalize(
|
||||||
|
policy.config.output_features, policy.config.normalization_mapping, dummy_stats
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create test batch
|
||||||
|
batch_size = 1
|
||||||
|
batch = {
|
||||||
|
"observation.state": torch.randn(
|
||||||
|
batch_size, policy.config.state_dim, dtype=torch.float32, device=device
|
||||||
|
),
|
||||||
|
"action": torch.randn(
|
||||||
|
batch_size, policy.config.chunk_size, policy.config.action_dim, dtype=torch.float32, device=device
|
||||||
|
),
|
||||||
|
"task": ["Pick up the object"] * batch_size,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add images based on config
|
||||||
|
for key in policy.config.image_keys:
|
||||||
|
batch[key] = torch.rand(batch_size, 3, 224, 224, dtype=torch.float32, device=device)
|
||||||
|
|
||||||
|
# Test forward pass
|
||||||
|
print(f"\nTesting forward pass for {model_id}...")
|
||||||
|
try:
|
||||||
|
policy.train()
|
||||||
|
loss, loss_dict = policy.forward(batch)
|
||||||
|
assert not torch.isnan(loss), f"{model_id}: Forward pass produced NaN loss"
|
||||||
|
assert loss.item() >= 0, f"{model_id}: Loss should be non-negative"
|
||||||
|
print(f"✓ Forward pass successful - Loss: {loss_dict['loss']:.4f}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"✗ Forward pass failed for {model_id}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Test action prediction
|
||||||
|
print(f"Testing action prediction for {model_id}...")
|
||||||
|
try:
|
||||||
|
policy.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
action = policy.select_action(batch)
|
||||||
|
expected_shape = (batch_size, policy.config.action_dim)
|
||||||
|
assert action.shape == expected_shape, (
|
||||||
|
f"{model_id}: Expected action shape {expected_shape}, got {action.shape}"
|
||||||
|
)
|
||||||
|
assert not torch.isnan(action).any(), f"{model_id}: Action contains NaN values"
|
||||||
|
print(f"✓ Action prediction successful - Shape: {action.shape}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"✗ Action prediction failed for {model_id}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
print(f"✅ All tests passed for {model_id}!")
|
||||||
|
|||||||
Reference in New Issue
Block a user