mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 20:19:43 +00:00
initial commit
This commit is contained in:
+136
@@ -0,0 +1,136 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
"""Test script to load PI0OpenPI model from HuggingFace hub and run inference."""
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.policies.pi0_openpi import PI0OpenPIPolicy
|
||||
|
||||
|
||||
def test_hub_loading():
|
||||
"""Test loading model from HuggingFace hub."""
|
||||
print("=" * 60)
|
||||
print("PI0OpenPI HuggingFace Hub Loading Test")
|
||||
print("=" * 60)
|
||||
|
||||
# Model ID on HuggingFace hub
|
||||
model_id = "pepijn223/pi0_base_fp32" # We made sure this config matches our code and `PI0OpenPIConfig` by uploading a model with push_pi0_to_hub.py and copying that config.
|
||||
|
||||
print(f"\nLoading model from: {model_id}")
|
||||
print("-" * 60)
|
||||
|
||||
try:
|
||||
# Load the model from HuggingFace hub with strict mode
|
||||
policy = PI0OpenPIPolicy.from_pretrained(
|
||||
model_id,
|
||||
strict=True, # Ensure all weights are loaded correctly
|
||||
)
|
||||
print("✓ Model loaded successfully from HuggingFace hub")
|
||||
|
||||
# Get model info
|
||||
print("\nModel configuration:")
|
||||
print(f" - PaliGemma variant: {policy.config.paligemma_variant}")
|
||||
print(f" - Action expert variant: {policy.config.action_expert_variant}")
|
||||
print(f" - Action dimension: {policy.config.action_dim}")
|
||||
print(f" - State dimension: {policy.config.state_dim}")
|
||||
print(f" - Action horizon: {policy.config.action_horizon}")
|
||||
print(f" - Device: {next(policy.parameters()).device}")
|
||||
print(f" - Dtype: {next(policy.parameters()).dtype}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Failed to load model: {e}")
|
||||
return False
|
||||
|
||||
print("\n" + "-" * 60)
|
||||
print("Testing forward pass with loaded model...")
|
||||
|
||||
# Create dummy batch for testing
|
||||
batch_size = 1
|
||||
device = next(policy.parameters()).device
|
||||
|
||||
# Create dummy dataset stats if not loaded with the model
|
||||
if not hasattr(policy, "normalize_inputs") or policy.normalize_inputs is None:
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
|
||||
dataset_stats = {
|
||||
"observation.state": {
|
||||
"mean": torch.zeros(policy.config.state_dim, device=device),
|
||||
"std": torch.ones(policy.config.state_dim, device=device),
|
||||
},
|
||||
"action": {
|
||||
"mean": torch.zeros(policy.config.action_dim, device=device),
|
||||
"std": torch.ones(policy.config.action_dim, device=device),
|
||||
},
|
||||
}
|
||||
policy.normalize_inputs = Normalize(
|
||||
policy.config.input_features, policy.config.normalization_mapping, dataset_stats
|
||||
)
|
||||
policy.normalize_targets = Normalize(
|
||||
policy.config.output_features, policy.config.normalization_mapping, dataset_stats
|
||||
)
|
||||
policy.unnormalize_outputs = Unnormalize(
|
||||
policy.config.output_features, policy.config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
# Create test batch
|
||||
batch = {
|
||||
"observation.state": torch.randn(
|
||||
batch_size, policy.config.state_dim, dtype=torch.float32, device=device
|
||||
),
|
||||
"action": torch.randn(
|
||||
batch_size,
|
||||
policy.config.action_horizon,
|
||||
policy.config.action_dim,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
),
|
||||
"task": ["Pick up the object"] * batch_size,
|
||||
}
|
||||
|
||||
# Add images if they're in the config
|
||||
for key in policy.config.image_keys:
|
||||
batch[key] = torch.rand(batch_size, 3, 224, 224, dtype=torch.float32, device=device)
|
||||
|
||||
try:
|
||||
# Test forward pass
|
||||
policy.train() # Set to training mode for forward pass with loss
|
||||
loss, loss_dict = policy.forward(batch)
|
||||
print("✓ Forward pass successful")
|
||||
print(f" - Loss: {loss_dict['loss']:.4f}")
|
||||
print(f" - Loss shape: {loss.shape if hasattr(loss, 'shape') else 'scalar'}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Forward pass failed: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
print("\n" + "-" * 60)
|
||||
print("Testing inference with loaded model...")
|
||||
|
||||
try:
|
||||
# Test action prediction
|
||||
policy.eval() # Set to evaluation mode for inference
|
||||
with torch.no_grad():
|
||||
action = policy.select_action(batch)
|
||||
print("✓ Action prediction successful")
|
||||
print(f" - Action shape: {action.shape}")
|
||||
print(f" - Action range: [{action.min().item():.3f}, {action.max().item():.3f}]")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Action prediction failed: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("✓ All tests passed!")
|
||||
print("=" * 60)
|
||||
return True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = test_hub_loading()
|
||||
exit(0 if success else 1)
|
||||
Reference in New Issue
Block a user