mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 13:40:00 +00:00
add pi05
This commit is contained in:
@@ -154,41 +154,6 @@ def create_and_push_model(
|
||||
|
||||
print(f"\n✓ Model successfully uploaded to: https://huggingface.co/{repo_id}")
|
||||
|
||||
# Test loading the model back
|
||||
print("\n" + "-" * 60)
|
||||
print("Testing model loading from hub...")
|
||||
|
||||
try:
|
||||
loaded_policy = PI0OpenPIPolicy.from_pretrained(
|
||||
repo_id,
|
||||
token=token,
|
||||
)
|
||||
print("✓ Model loaded successfully from hub")
|
||||
|
||||
# Quick validation
|
||||
batch_size = 1
|
||||
device = next(loaded_policy.parameters()).device
|
||||
test_batch = {
|
||||
"observation.state": torch.randn(batch_size, config.state_dim, device=device),
|
||||
"action": torch.randn(batch_size, config.action_horizon, config.action_dim, device=device),
|
||||
"task": ["Test task"],
|
||||
}
|
||||
|
||||
# Add images
|
||||
for key in config.image_keys:
|
||||
test_batch[key] = torch.rand(batch_size, 3, 224, 224, device=device)
|
||||
|
||||
# Test forward pass
|
||||
loaded_policy.train()
|
||||
loss, loss_dict = loaded_policy.forward(test_batch)
|
||||
print(f"✓ Forward pass successful - Loss: {loss_dict['loss']:.4f}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Failed to load model: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("✓ Process complete!")
|
||||
print("=" * 60)
|
||||
|
||||
Reference in New Issue
Block a user