mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 04:59:47 +00:00
split pi0 and pi05 policy in seperate files
This commit is contained in:
+14
-6
@@ -5,6 +5,7 @@
|
||||
import torch
|
||||
|
||||
from lerobot.policies.pi0_openpi import PI0OpenPIPolicy
|
||||
from lerobot.policies.pi05_openpi.modeling_pi05openpi import PI05OpenPIPolicy
|
||||
|
||||
|
||||
def create_dummy_stats(config):
|
||||
@@ -46,10 +47,17 @@ def test_hub_loading(model_id="pepijn223/pi0_base_fp32", model_name="PI0"):
|
||||
|
||||
try:
|
||||
# Load the model from HuggingFace hub with strict mode
|
||||
policy = PI0OpenPIPolicy.from_pretrained(
|
||||
model_id,
|
||||
strict=True, # Ensure all weights are loaded correctly
|
||||
)
|
||||
if model_name == "PI0.5":
|
||||
policy = PI05OpenPIPolicy.from_pretrained(
|
||||
model_id,
|
||||
strict=True, # Ensure all weights are loaded correctly,
|
||||
)
|
||||
else:
|
||||
policy = PI0OpenPIPolicy.from_pretrained(
|
||||
model_id,
|
||||
strict=True, # Ensure all weights are loaded correctly,
|
||||
)
|
||||
|
||||
print("✓ Model loaded successfully from HuggingFace hub")
|
||||
|
||||
# Inject dummy stats since they aren't loaded from the hub
|
||||
@@ -69,7 +77,7 @@ def test_hub_loading(model_id="pepijn223/pi0_base_fp32", model_name="PI0"):
|
||||
|
||||
# Get model info
|
||||
print("\nModel configuration:")
|
||||
print(f" - Model type: {'PI0.5' if policy.config.pi05 else 'PI0'}")
|
||||
print(f" - Model type: {model_name}")
|
||||
print(f" - PaliGemma variant: {policy.config.paligemma_variant}")
|
||||
print(f" - Action expert variant: {policy.config.action_expert_variant}")
|
||||
print(f" - Action dimension: {policy.config.action_dim}")
|
||||
@@ -81,7 +89,7 @@ def test_hub_loading(model_id="pepijn223/pi0_base_fp32", model_name="PI0"):
|
||||
print(f" - Dtype: {next(policy.parameters()).dtype}")
|
||||
|
||||
# Check model-specific features
|
||||
if policy.config.pi05:
|
||||
if model_name == "PI0.5":
|
||||
print("\nPI0.5 specific features:")
|
||||
print(f" - Has time_mlp layers: {hasattr(policy.model, 'time_mlp_in')}")
|
||||
print(f" - Has state_proj: {hasattr(policy.model, 'state_proj')} (should be False)")
|
||||
|
||||
Reference in New Issue
Block a user