split pi0 and pi05 policy in seperate files

This commit is contained in:
Pepijn
2025-09-11 09:04:46 +02:00
parent d36bdac114
commit 9f7bfeb419
12 changed files with 4491 additions and 94 deletions
+14 -6
View File
@@ -5,6 +5,7 @@
import torch
from lerobot.policies.pi0_openpi import PI0OpenPIPolicy
from lerobot.policies.pi05_openpi.modeling_pi05openpi import PI05OpenPIPolicy
def create_dummy_stats(config):
@@ -46,10 +47,17 @@ def test_hub_loading(model_id="pepijn223/pi0_base_fp32", model_name="PI0"):
try:
# Load the model from HuggingFace hub with strict mode
policy = PI0OpenPIPolicy.from_pretrained(
model_id,
strict=True, # Ensure all weights are loaded correctly
)
if model_name == "PI0.5":
policy = PI05OpenPIPolicy.from_pretrained(
model_id,
strict=True, # Ensure all weights are loaded correctly,
)
else:
policy = PI0OpenPIPolicy.from_pretrained(
model_id,
strict=True, # Ensure all weights are loaded correctly,
)
print("✓ Model loaded successfully from HuggingFace hub")
# Inject dummy stats since they aren't loaded from the hub
@@ -69,7 +77,7 @@ def test_hub_loading(model_id="pepijn223/pi0_base_fp32", model_name="PI0"):
# Get model info
print("\nModel configuration:")
print(f" - Model type: {'PI0.5' if policy.config.pi05 else 'PI0'}")
print(f" - Model type: {model_name}")
print(f" - PaliGemma variant: {policy.config.paligemma_variant}")
print(f" - Action expert variant: {policy.config.action_expert_variant}")
print(f" - Action dimension: {policy.config.action_dim}")
@@ -81,7 +89,7 @@ def test_hub_loading(model_id="pepijn223/pi0_base_fp32", model_name="PI0"):
print(f" - Dtype: {next(policy.parameters()).dtype}")
# Check model-specific features
if policy.config.pi05:
if model_name == "PI0.5":
print("\nPI0.5 specific features:")
print(f" - Has time_mlp layers: {hasattr(policy.model, 'time_mlp_in')}")
print(f" - Has state_proj: {hasattr(policy.model, 'state_proj')} (should be False)")