This commit is contained in:
Pepijn
2025-09-10 21:33:55 +02:00
parent b028907d21
commit ac323b0113
6 changed files with 422 additions and 46 deletions
-35
View File
@@ -154,41 +154,6 @@ def create_and_push_model(
print(f"\n✓ Model successfully uploaded to: https://huggingface.co/{repo_id}")
# Test loading the model back
print("\n" + "-" * 60)
print("Testing model loading from hub...")
try:
loaded_policy = PI0OpenPIPolicy.from_pretrained(
repo_id,
token=token,
)
print("✓ Model loaded successfully from hub")
# Quick validation
batch_size = 1
device = next(loaded_policy.parameters()).device
test_batch = {
"observation.state": torch.randn(batch_size, config.state_dim, device=device),
"action": torch.randn(batch_size, config.action_horizon, config.action_dim, device=device),
"task": ["Test task"],
}
# Add images
for key in config.image_keys:
test_batch[key] = torch.rand(batch_size, 3, 224, 224, device=device)
# Test forward pass
loaded_policy.train()
loss, loss_dict = loaded_policy.forward(test_batch)
print(f"✓ Forward pass successful - Loss: {loss_dict['loss']:.4f}")
except Exception as e:
print(f"✗ Failed to load model: {e}")
import traceback
traceback.print_exc()
print("\n" + "=" * 60)
print("✓ Process complete!")
print("=" * 60)