mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-14 16:19:45 +00:00
add pi05
This commit is contained in:
@@ -154,41 +154,6 @@ def create_and_push_model(
|
||||
|
||||
print(f"\n✓ Model successfully uploaded to: https://huggingface.co/{repo_id}")
|
||||
|
||||
# Test loading the model back
|
||||
print("\n" + "-" * 60)
|
||||
print("Testing model loading from hub...")
|
||||
|
||||
try:
|
||||
loaded_policy = PI0OpenPIPolicy.from_pretrained(
|
||||
repo_id,
|
||||
token=token,
|
||||
)
|
||||
print("✓ Model loaded successfully from hub")
|
||||
|
||||
# Quick validation
|
||||
batch_size = 1
|
||||
device = next(loaded_policy.parameters()).device
|
||||
test_batch = {
|
||||
"observation.state": torch.randn(batch_size, config.state_dim, device=device),
|
||||
"action": torch.randn(batch_size, config.action_horizon, config.action_dim, device=device),
|
||||
"task": ["Test task"],
|
||||
}
|
||||
|
||||
# Add images
|
||||
for key in config.image_keys:
|
||||
test_batch[key] = torch.rand(batch_size, 3, 224, 224, device=device)
|
||||
|
||||
# Test forward pass
|
||||
loaded_policy.train()
|
||||
loss, loss_dict = loaded_policy.forward(test_batch)
|
||||
print(f"✓ Forward pass successful - Loss: {loss_dict['loss']:.4f}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Failed to load model: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("✓ Process complete!")
|
||||
print("=" * 60)
|
||||
|
||||
@@ -45,6 +45,19 @@ from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
# pip install transformers==4.53.2
|
||||
|
||||
|
||||
# Comparison of PI0 vs PI0.5
|
||||
#
|
||||
# Feature | PI0 | PI0.5
|
||||
# ---------------------|---------------------------------------------|-----------------------------------------
|
||||
# State Embedding | Uses state_proj layer | No state embedding
|
||||
# Time Conditioning | Concatenates time with actions via | Uses time_mlp_* for AdaRMS conditioning
|
||||
# | action_time_mlp_* |
|
||||
# AdaRMS | Not used | Used in action expert
|
||||
# Tokenizer Length | 200 tokens | 48 tokens
|
||||
# discrete_state_input | False | True
|
||||
# Parameter Count | Higher (includes state_proj) | Lower (no state embedding)
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("pi0_openpi")
|
||||
@dataclass
|
||||
class PI0OpenPIConfig(PreTrainedConfig):
|
||||
@@ -52,6 +65,7 @@ class PI0OpenPIConfig(PreTrainedConfig):
|
||||
paligemma_variant: str = "gemma_2b"
|
||||
action_expert_variant: str = "gemma_300m"
|
||||
pi05: bool = False # Whether to use PI0.5 variant with AdaRMS
|
||||
discrete_state_input: bool | None = None # Whether to use discrete state input (defaults to pi05 value)
|
||||
dtype: str = "float32" # Options: "bfloat16", "float32"
|
||||
|
||||
# Input / output structure
|
||||
@@ -108,6 +122,16 @@ class PI0OpenPIConfig(PreTrainedConfig):
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
# Set discrete_state_input to pi05 value if not explicitly set
|
||||
if self.discrete_state_input is None: # see openpi `Pi0Config, __post_init__`
|
||||
object.__setattr__(self, "discrete_state_input", self.pi05)
|
||||
|
||||
# Set tokenizer max length based on pi05 mode, see openpi `Pi0Config, __post_init__`
|
||||
if self.pi05:
|
||||
self.tokenizer_max_length = 48
|
||||
else:
|
||||
self.tokenizer_max_length = 200
|
||||
|
||||
# Validate configuration
|
||||
if self.n_action_steps > self.action_horizon:
|
||||
raise ValueError(
|
||||
|
||||
@@ -924,11 +924,14 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
|
||||
print(f"Could not load state dict from remote files: {e}")
|
||||
return model
|
||||
|
||||
# Create a new state dict with "model." prefix for all keys that don't already have it
|
||||
# First, fix any pi05-specific key differences
|
||||
fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
|
||||
|
||||
# Then add "model." prefix for all keys that don't already have it
|
||||
remapped_state_dict = {}
|
||||
remap_count = 0
|
||||
|
||||
for key, value in original_state_dict.items():
|
||||
for key, value in fixed_state_dict.items():
|
||||
if not key.startswith("model."):
|
||||
new_key = f"model.{key}"
|
||||
remapped_state_dict[new_key] = value
|
||||
@@ -975,6 +978,59 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
|
||||
|
||||
return model
|
||||
|
||||
def _fix_pytorch_state_dict_keys(
|
||||
self, state_dict, model_config
|
||||
): # see openpi `BaseModelConfig, _fix_pytorch_state_dict_keys`
|
||||
"""Fix state dict keys to match current model architecture."""
|
||||
import re
|
||||
|
||||
fixed_state_dict = {}
|
||||
|
||||
for key, value in state_dict.items():
|
||||
new_key = key
|
||||
|
||||
# Handle layer norm structure changes: .weight -> .dense.weight + .dense.bias
|
||||
# For gemma expert layers
|
||||
if re.match(
|
||||
r"paligemma_with_expert\.gemma_expert\.model\.layers\.\d+\.(input_layernorm|post_attention_layernorm)\.weight",
|
||||
key,
|
||||
):
|
||||
# This key structure suggests old model without adaRMS - keep as is or skip
|
||||
logging.warning(f"Skipping old layer norm key (no adaRMS support): {key}")
|
||||
continue
|
||||
|
||||
if re.match(r"paligemma_with_expert\.gemma_expert\.model\.norm\.weight", key):
|
||||
# Skip old norm structure
|
||||
logging.warning(f"Skipping old norm key (no adaRMS support): {key}")
|
||||
continue
|
||||
|
||||
# Handle MLP naming changes for pi05 vs non-pi05
|
||||
if model_config.pi05:
|
||||
# pi05 model expects time_mlp_*, but checkpoint might have action_time_mlp_*
|
||||
if key.startswith("action_time_mlp_in."):
|
||||
new_key = key.replace("action_time_mlp_in.", "time_mlp_in.")
|
||||
elif key.startswith("action_time_mlp_out."):
|
||||
new_key = key.replace("action_time_mlp_out.", "time_mlp_out.")
|
||||
# Also handle state_proj which shouldn't exist in pi05
|
||||
if key.startswith("state_proj."):
|
||||
logging.warning(f"Skipping state_proj key in pi05 mode: {key}")
|
||||
continue
|
||||
else:
|
||||
# non-pi05 model expects action_time_mlp_*, but checkpoint might have time_mlp_*
|
||||
if key.startswith("time_mlp_in."):
|
||||
new_key = key.replace("time_mlp_in.", "action_time_mlp_in.")
|
||||
elif key.startswith("time_mlp_out."):
|
||||
new_key = key.replace("time_mlp_out.", "action_time_mlp_out.")
|
||||
|
||||
# Handle vision tower embedding layer potential differences
|
||||
if "patch_embedding" in key:
|
||||
# Some checkpoints might have this, but current model expects different structure
|
||||
logging.warning(f"Vision embedding key might need handling: {key}")
|
||||
|
||||
fixed_state_dict[new_key] = value
|
||||
|
||||
return fixed_state_dict
|
||||
|
||||
def get_optim_params(self) -> dict: # see lerobot pi0 `get_optim_params`
|
||||
return self.parameters()
|
||||
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
#!/usr/bin/env python
|
||||
# TODO(pepijn): delete this file
|
||||
"""Quick test script to load and test only the PI0.5 model from HuggingFace hub."""
|
||||
|
||||
from test_pi0_hub import test_hub_loading
|
||||
|
||||
|
||||
def main():
|
||||
"""Test only the PI0.5 model."""
|
||||
print("\n")
|
||||
print("=" * 60)
|
||||
print("PI0.5 Model Quick Test")
|
||||
print("=" * 60)
|
||||
|
||||
success = test_hub_loading(model_id="pepijn223/pi05_base_fp32", model_name="PI0.5")
|
||||
|
||||
if success:
|
||||
print("\n✅ PI0.5 model loaded and tested successfully!")
|
||||
else:
|
||||
print("\n❌ PI0.5 test failed!")
|
||||
|
||||
return success
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = main()
|
||||
exit(0 if success else 1)
|
||||
@@ -0,0 +1,227 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
"""Test script to verify PI0.5 (pi05) support in PI0OpenPI policy."""
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.policies.pi0_openpi import PI0OpenPIConfig, PI0OpenPIPolicy
|
||||
|
||||
|
||||
def test_pi05_model_architecture():
|
||||
"""Test that pi05=True creates the correct model architecture."""
|
||||
print("Testing PI0.5 model architecture...")
|
||||
|
||||
# Create config with pi05=True
|
||||
config = PI0OpenPIConfig(
|
||||
action_dim=7,
|
||||
state_dim=14,
|
||||
dtype="float32",
|
||||
pi05=True, # Enable PI0.5 mode
|
||||
)
|
||||
|
||||
# Verify tokenizer max length is set correctly
|
||||
assert config.tokenizer_max_length == 48, (
|
||||
f"Expected tokenizer_max_length=48 for pi05, got {config.tokenizer_max_length}"
|
||||
)
|
||||
print(f"✓ Tokenizer max length correctly set to {config.tokenizer_max_length}")
|
||||
|
||||
# Verify discrete_state_input defaults to pi05 value
|
||||
assert config.discrete_state_input == True, ( # noqa: E712
|
||||
f"Expected discrete_state_input=True for pi05, got {config.discrete_state_input}"
|
||||
)
|
||||
print(f"✓ discrete_state_input correctly defaults to pi05 value: {config.discrete_state_input}")
|
||||
|
||||
# Create dummy dataset stats
|
||||
dataset_stats = {
|
||||
"observation.state": {
|
||||
"mean": torch.zeros(14),
|
||||
"std": torch.ones(14),
|
||||
},
|
||||
"action": {
|
||||
"mean": torch.zeros(7),
|
||||
"std": torch.ones(7),
|
||||
},
|
||||
}
|
||||
|
||||
# Instantiate policy
|
||||
policy = PI0OpenPIPolicy(config, dataset_stats)
|
||||
|
||||
# Verify pi05 model components exist
|
||||
assert policy.model.pi05 == True, "Model pi05 flag not set" # noqa: E712
|
||||
print("✓ PI0.5 mode enabled in model")
|
||||
|
||||
# Check that time_mlp layers exist (for AdaRMS conditioning)
|
||||
assert hasattr(policy.model, "time_mlp_in"), "Missing time_mlp_in layer for pi05"
|
||||
assert hasattr(policy.model, "time_mlp_out"), "Missing time_mlp_out layer for pi05"
|
||||
print("✓ Time MLP layers present for AdaRMS conditioning")
|
||||
|
||||
# Check that action_time_mlp layers don't exist (pi0 only)
|
||||
assert not hasattr(policy.model, "action_time_mlp_in"), "action_time_mlp_in should not exist in pi05 mode"
|
||||
assert not hasattr(policy.model, "action_time_mlp_out"), (
|
||||
"action_time_mlp_out should not exist in pi05 mode"
|
||||
)
|
||||
print("✓ Action-time MLP layers correctly absent")
|
||||
|
||||
# Check that state_proj doesn't exist in pi05 mode
|
||||
assert not hasattr(policy.model, "state_proj"), "state_proj should not exist in pi05 mode"
|
||||
print("✓ State projection layer correctly absent")
|
||||
|
||||
# Check AdaRMS configuration in the underlying model
|
||||
adarms_config = policy.model.paligemma_with_expert.paligemma.config.text_config.use_adarms
|
||||
assert adarms_config == False, f"PaliGemma should not use AdaRMS, got {adarms_config}" # noqa: E712
|
||||
|
||||
adarms_expert_config = policy.model.paligemma_with_expert.gemma_expert.config.use_adarms
|
||||
assert adarms_expert_config == True, ( # noqa: E712
|
||||
f"Action expert should use AdaRMS in pi05, got {adarms_expert_config}"
|
||||
)
|
||||
print("✓ AdaRMS correctly configured: PaliGemma=False, Expert=True")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def test_pi05_forward_pass():
|
||||
"""Test forward pass with pi05=True."""
|
||||
print("\nTesting PI0.5 forward pass...")
|
||||
|
||||
# Create config with pi05=True
|
||||
config = PI0OpenPIConfig(
|
||||
action_dim=7,
|
||||
state_dim=14,
|
||||
dtype="float32",
|
||||
pi05=True,
|
||||
action_horizon=16, # Shorter horizon for testing
|
||||
)
|
||||
|
||||
# Create dummy dataset stats
|
||||
dataset_stats = {
|
||||
"observation.state": {
|
||||
"mean": torch.zeros(14),
|
||||
"std": torch.ones(14),
|
||||
},
|
||||
"action": {
|
||||
"mean": torch.zeros(7),
|
||||
"std": torch.ones(7),
|
||||
},
|
||||
}
|
||||
|
||||
# Instantiate policy
|
||||
policy = PI0OpenPIPolicy(config, dataset_stats)
|
||||
|
||||
# Create test batch
|
||||
batch_size = 2
|
||||
device = next(policy.parameters()).device
|
||||
batch = {
|
||||
"observation.state": torch.randn(batch_size, 14, dtype=torch.float32, device=device),
|
||||
"action": torch.randn(batch_size, config.action_horizon, 7, dtype=torch.float32, device=device),
|
||||
"observation.images.base_0_rgb": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||
),
|
||||
"task": ["Pick up the object"] * batch_size,
|
||||
}
|
||||
|
||||
# Test forward pass
|
||||
try:
|
||||
loss, loss_dict = policy.forward(batch)
|
||||
print(f"✓ Forward pass successful. Loss: {loss_dict['loss']:.4f}")
|
||||
assert not torch.isnan(loss), "Loss is NaN"
|
||||
assert loss.item() >= 0, "Loss should be non-negative"
|
||||
except Exception as e:
|
||||
print(f"✗ Forward pass failed: {e}")
|
||||
return False
|
||||
|
||||
# Test action prediction
|
||||
try:
|
||||
with torch.no_grad():
|
||||
action = policy.select_action(batch)
|
||||
print(f"✓ Action prediction successful. Action shape: {action.shape}")
|
||||
assert action.shape == (7,), f"Expected action shape (7,), got {action.shape}"
|
||||
assert not torch.isnan(action).any(), "Action contains NaN values"
|
||||
except Exception as e:
|
||||
print(f"✗ Action prediction failed: {e}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def test_pi0_vs_pi05_differences():
|
||||
"""Test key differences between pi0 and pi05 modes."""
|
||||
print("\nComparing PI0 vs PI0.5 architectures...")
|
||||
|
||||
# Create both configurations
|
||||
config_pi0 = PI0OpenPIConfig(action_dim=7, state_dim=14, dtype="float32", pi05=False)
|
||||
config_pi05 = PI0OpenPIConfig(action_dim=7, state_dim=14, dtype="float32", pi05=True)
|
||||
|
||||
dataset_stats = {
|
||||
"observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
|
||||
"action": {"mean": torch.zeros(7), "std": torch.ones(7)},
|
||||
}
|
||||
|
||||
# Create both models
|
||||
policy_pi0 = PI0OpenPIPolicy(config_pi0, dataset_stats)
|
||||
policy_pi05 = PI0OpenPIPolicy(config_pi05, dataset_stats)
|
||||
|
||||
print("\nPI0 Model:")
|
||||
print(f" - Tokenizer max length: {config_pi0.tokenizer_max_length}")
|
||||
print(f" - discrete_state_input: {config_pi0.discrete_state_input}")
|
||||
print(f" - Has state_proj: {hasattr(policy_pi0.model, 'state_proj')}")
|
||||
print(f" - Has action_time_mlp: {hasattr(policy_pi0.model, 'action_time_mlp_in')}")
|
||||
print(f" - Has time_mlp: {hasattr(policy_pi0.model, 'time_mlp_in')}")
|
||||
print(f" - Uses AdaRMS: {policy_pi0.model.paligemma_with_expert.gemma_expert.config.use_adarms}")
|
||||
|
||||
print("\nPI0.5 Model:")
|
||||
print(f" - Tokenizer max length: {config_pi05.tokenizer_max_length}")
|
||||
print(f" - discrete_state_input: {config_pi05.discrete_state_input}")
|
||||
print(f" - Has state_proj: {hasattr(policy_pi05.model, 'state_proj')}")
|
||||
print(f" - Has action_time_mlp: {hasattr(policy_pi05.model, 'action_time_mlp_in')}")
|
||||
print(f" - Has time_mlp: {hasattr(policy_pi05.model, 'time_mlp_in')}")
|
||||
print(f" - Uses AdaRMS: {policy_pi05.model.paligemma_with_expert.gemma_expert.config.use_adarms}")
|
||||
|
||||
# Count parameters
|
||||
pi0_params = sum(p.numel() for p in policy_pi0.parameters())
|
||||
pi05_params = sum(p.numel() for p in policy_pi05.parameters())
|
||||
|
||||
print("\nParameter counts:")
|
||||
print(f" - PI0: {pi0_params:,}")
|
||||
print(f" - PI0.5: {pi05_params:,}")
|
||||
print(f" - Difference: {pi0_params - pi05_params:,} (PI0.5 has fewer params due to no state embedding)")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def main():
|
||||
"""Run all PI0.5 tests."""
|
||||
print("=" * 60)
|
||||
print("PI0.5 Support Test Suite")
|
||||
print("=" * 60)
|
||||
|
||||
tests = [
|
||||
("Model Architecture", test_pi05_model_architecture),
|
||||
("Forward Pass", test_pi05_forward_pass),
|
||||
("PI0 vs PI0.5 Comparison", test_pi0_vs_pi05_differences),
|
||||
]
|
||||
|
||||
all_passed = True
|
||||
for test_name, test_func in tests:
|
||||
print(f"\n[{test_name}]")
|
||||
print("-" * 40)
|
||||
try:
|
||||
if not test_func():
|
||||
all_passed = False
|
||||
print(f"✗ {test_name} failed")
|
||||
except Exception as e:
|
||||
all_passed = False
|
||||
print(f"✗ {test_name} failed with exception: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
if all_passed:
|
||||
print("✅ All PI0.5 tests passed!")
|
||||
else:
|
||||
print("❌ Some tests failed.")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
+86
-9
@@ -30,14 +30,16 @@ def create_dummy_stats(config):
|
||||
return dummy_stats
|
||||
|
||||
|
||||
def test_hub_loading():
|
||||
"""Test loading model from HuggingFace hub."""
|
||||
print("=" * 60)
|
||||
print("PI0OpenPI HuggingFace Hub Loading Test")
|
||||
print("=" * 60)
|
||||
def test_hub_loading(model_id="pepijn223/pi0_base_fp32", model_name="PI0"):
|
||||
"""Test loading model from HuggingFace hub.
|
||||
|
||||
# Model ID on HuggingFace hub
|
||||
model_id = "pepijn223/pi0_base_fp32" # We made sure this config matches our code and `PI0OpenPIConfig` by uploading a model with push_pi0_to_hub.py and copying that config.
|
||||
Args:
|
||||
model_id: HuggingFace model ID to load
|
||||
model_name: Display name for the model (e.g., "PI0", "PI0.5")
|
||||
"""
|
||||
print("=" * 60)
|
||||
print(f"{model_name} OpenPI HuggingFace Hub Loading Test")
|
||||
print("=" * 60)
|
||||
|
||||
print(f"\nLoading model from: {model_id}")
|
||||
print("-" * 60)
|
||||
@@ -67,14 +69,45 @@ def test_hub_loading():
|
||||
|
||||
# Get model info
|
||||
print("\nModel configuration:")
|
||||
print(f" - Model type: {'PI0.5' if policy.config.pi05 else 'PI0'}")
|
||||
print(f" - PaliGemma variant: {policy.config.paligemma_variant}")
|
||||
print(f" - Action expert variant: {policy.config.action_expert_variant}")
|
||||
print(f" - Action dimension: {policy.config.action_dim}")
|
||||
print(f" - State dimension: {policy.config.state_dim}")
|
||||
print(f" - Action horizon: {policy.config.action_horizon}")
|
||||
print(f" - Tokenizer max length: {policy.config.tokenizer_max_length}")
|
||||
print(f" - discrete_state_input: {policy.config.discrete_state_input}")
|
||||
print(f" - Device: {device}")
|
||||
print(f" - Dtype: {next(policy.parameters()).dtype}")
|
||||
|
||||
# Check model-specific features
|
||||
if policy.config.pi05:
|
||||
print("\nPI0.5 specific features:")
|
||||
print(f" - Has time_mlp layers: {hasattr(policy.model, 'time_mlp_in')}")
|
||||
print(f" - Has state_proj: {hasattr(policy.model, 'state_proj')} (should be False)")
|
||||
print(f" - Uses AdaRMS: {policy.model.paligemma_with_expert.gemma_expert.config.use_adarms}")
|
||||
|
||||
# Verify PI0.5 architecture
|
||||
assert hasattr(policy.model, "time_mlp_in"), "PI0.5 should have time_mlp_in"
|
||||
assert hasattr(policy.model, "time_mlp_out"), "PI0.5 should have time_mlp_out"
|
||||
assert not hasattr(policy.model, "state_proj"), "PI0.5 should not have state_proj"
|
||||
assert not hasattr(policy.model, "action_time_mlp_in"), "PI0.5 should not have action_time_mlp_in"
|
||||
print(" ✓ PI0.5 architecture verified")
|
||||
else:
|
||||
print("\nPI0 specific features:")
|
||||
print(f" - Has action_time_mlp layers: {hasattr(policy.model, 'action_time_mlp_in')}")
|
||||
print(f" - Has state_proj: {hasattr(policy.model, 'state_proj')} (should be True)")
|
||||
print(
|
||||
f" - Uses AdaRMS: {policy.model.paligemma_with_expert.gemma_expert.config.use_adarms} (should be False)"
|
||||
)
|
||||
|
||||
# Verify PI0 architecture
|
||||
assert hasattr(policy.model, "action_time_mlp_in"), "PI0 should have action_time_mlp_in"
|
||||
assert hasattr(policy.model, "action_time_mlp_out"), "PI0 should have action_time_mlp_out"
|
||||
assert hasattr(policy.model, "state_proj"), "PI0 should have state_proj"
|
||||
assert not hasattr(policy.model, "time_mlp_in"), "PI0 should not have time_mlp_in"
|
||||
print(" ✓ PI0 architecture verified")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Failed to load model: {e}")
|
||||
return False
|
||||
@@ -177,11 +210,55 @@ def test_hub_loading():
|
||||
return False
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("✓ All tests passed!")
|
||||
print(f"✓ All tests passed for {model_name}!")
|
||||
print("=" * 60)
|
||||
return True
|
||||
|
||||
|
||||
def main():
|
||||
"""Run tests for both PI0 and PI0.5 models."""
|
||||
print("\n")
|
||||
print("╔" + "═" * 58 + "╗")
|
||||
print("║" + " PI0 & PI0.5 HuggingFace Hub Loading Test Suite ".center(58) + "║")
|
||||
print("╚" + "═" * 58 + "╝")
|
||||
print()
|
||||
|
||||
results = []
|
||||
|
||||
# Test PI0 model
|
||||
print("\n[Test 1/2] Testing PI0 model...")
|
||||
print("─" * 60)
|
||||
pi0_success = test_hub_loading(model_id="pepijn223/pi0_base_fp32", model_name="PI0")
|
||||
results.append(("PI0", pi0_success))
|
||||
|
||||
# Test PI0.5 model
|
||||
print("\n\n[Test 2/2] Testing PI0.5 model...")
|
||||
print("─" * 60)
|
||||
pi05_success = test_hub_loading(model_id="pepijn223/pi05_base_fp32", model_name="PI0.5")
|
||||
results.append(("PI0.5", pi05_success))
|
||||
|
||||
# Summary
|
||||
print("\n\n")
|
||||
print("╔" + "═" * 58 + "╗")
|
||||
print("║" + " TEST SUMMARY ".center(58) + "║")
|
||||
print("╚" + "═" * 58 + "╝")
|
||||
|
||||
all_passed = True
|
||||
for model_name, success in results:
|
||||
status = "✅ PASSED" if success else "❌ FAILED"
|
||||
print(f" {model_name:10} : {status}")
|
||||
if not success:
|
||||
all_passed = False
|
||||
|
||||
print()
|
||||
if all_passed:
|
||||
print("🎉 All models loaded and tested successfully!")
|
||||
else:
|
||||
print("⚠️ Some tests failed. Check the output above for details.")
|
||||
|
||||
return all_passed
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = test_hub_loading()
|
||||
success = main()
|
||||
exit(0 if success else 1)
|
||||
|
||||
Reference in New Issue
Block a user