This commit is contained in:
Pepijn
2025-09-10 21:33:55 +02:00
parent b028907d21
commit ac323b0113
6 changed files with 422 additions and 46 deletions
+86 -9
View File
@@ -30,14 +30,16 @@ def create_dummy_stats(config):
return dummy_stats
def test_hub_loading():
"""Test loading model from HuggingFace hub."""
print("=" * 60)
print("PI0OpenPI HuggingFace Hub Loading Test")
print("=" * 60)
def test_hub_loading(model_id="pepijn223/pi0_base_fp32", model_name="PI0"):
"""Test loading model from HuggingFace hub.
# Model ID on HuggingFace hub
model_id = "pepijn223/pi0_base_fp32" # We made sure this config matches our code and `PI0OpenPIConfig` by uploading a model with push_pi0_to_hub.py and copying that config.
Args:
model_id: HuggingFace model ID to load
model_name: Display name for the model (e.g., "PI0", "PI0.5")
"""
print("=" * 60)
print(f"{model_name} OpenPI HuggingFace Hub Loading Test")
print("=" * 60)
print(f"\nLoading model from: {model_id}")
print("-" * 60)
@@ -67,14 +69,45 @@ def test_hub_loading():
# Get model info
print("\nModel configuration:")
print(f" - Model type: {'PI0.5' if policy.config.pi05 else 'PI0'}")
print(f" - PaliGemma variant: {policy.config.paligemma_variant}")
print(f" - Action expert variant: {policy.config.action_expert_variant}")
print(f" - Action dimension: {policy.config.action_dim}")
print(f" - State dimension: {policy.config.state_dim}")
print(f" - Action horizon: {policy.config.action_horizon}")
print(f" - Tokenizer max length: {policy.config.tokenizer_max_length}")
print(f" - discrete_state_input: {policy.config.discrete_state_input}")
print(f" - Device: {device}")
print(f" - Dtype: {next(policy.parameters()).dtype}")
# Check model-specific features
if policy.config.pi05:
print("\nPI0.5 specific features:")
print(f" - Has time_mlp layers: {hasattr(policy.model, 'time_mlp_in')}")
print(f" - Has state_proj: {hasattr(policy.model, 'state_proj')} (should be False)")
print(f" - Uses AdaRMS: {policy.model.paligemma_with_expert.gemma_expert.config.use_adarms}")
# Verify PI0.5 architecture
assert hasattr(policy.model, "time_mlp_in"), "PI0.5 should have time_mlp_in"
assert hasattr(policy.model, "time_mlp_out"), "PI0.5 should have time_mlp_out"
assert not hasattr(policy.model, "state_proj"), "PI0.5 should not have state_proj"
assert not hasattr(policy.model, "action_time_mlp_in"), "PI0.5 should not have action_time_mlp_in"
print(" ✓ PI0.5 architecture verified")
else:
print("\nPI0 specific features:")
print(f" - Has action_time_mlp layers: {hasattr(policy.model, 'action_time_mlp_in')}")
print(f" - Has state_proj: {hasattr(policy.model, 'state_proj')} (should be True)")
print(
f" - Uses AdaRMS: {policy.model.paligemma_with_expert.gemma_expert.config.use_adarms} (should be False)"
)
# Verify PI0 architecture
assert hasattr(policy.model, "action_time_mlp_in"), "PI0 should have action_time_mlp_in"
assert hasattr(policy.model, "action_time_mlp_out"), "PI0 should have action_time_mlp_out"
assert hasattr(policy.model, "state_proj"), "PI0 should have state_proj"
assert not hasattr(policy.model, "time_mlp_in"), "PI0 should not have time_mlp_in"
print(" ✓ PI0 architecture verified")
except Exception as e:
print(f"✗ Failed to load model: {e}")
return False
@@ -177,11 +210,55 @@ def test_hub_loading():
return False
print("\n" + "=" * 60)
print("✓ All tests passed!")
print(f"✓ All tests passed for {model_name}!")
print("=" * 60)
return True
def main():
"""Run tests for both PI0 and PI0.5 models."""
print("\n")
print("" + "" * 58 + "")
print("" + " PI0 & PI0.5 HuggingFace Hub Loading Test Suite ".center(58) + "")
print("" + "" * 58 + "")
print()
results = []
# Test PI0 model
print("\n[Test 1/2] Testing PI0 model...")
print("" * 60)
pi0_success = test_hub_loading(model_id="pepijn223/pi0_base_fp32", model_name="PI0")
results.append(("PI0", pi0_success))
# Test PI0.5 model
print("\n\n[Test 2/2] Testing PI0.5 model...")
print("" * 60)
pi05_success = test_hub_loading(model_id="pepijn223/pi05_base_fp32", model_name="PI0.5")
results.append(("PI0.5", pi05_success))
# Summary
print("\n\n")
print("" + "" * 58 + "")
print("" + " TEST SUMMARY ".center(58) + "")
print("" + "" * 58 + "")
all_passed = True
for model_name, success in results:
status = "✅ PASSED" if success else "❌ FAILED"
print(f" {model_name:10} : {status}")
if not success:
all_passed = False
print()
if all_passed:
print("🎉 All models loaded and tested successfully!")
else:
print("⚠️ Some tests failed. Check the output above for details.")
return all_passed
if __name__ == "__main__":
success = test_hub_loading()
success = main()
exit(0 if success else 1)