mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 19:49:49 +00:00
split pi0 and pi05 policy in seperate files
This commit is contained in:
+14
-16
@@ -4,19 +4,20 @@
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.policies.pi0_openpi import PI0OpenPIConfig, PI0OpenPIPolicy
|
||||
from lerobot.policies.pi0_openpi.configuration_pi0openpi import PI0OpenPIConfig
|
||||
from lerobot.policies.pi0_openpi.modeling_pi0openpi import PI0OpenPIPolicy
|
||||
from lerobot.policies.pi05_openpi import PI05OpenPIConfig, PI05OpenPIPolicy
|
||||
|
||||
|
||||
def test_pi05_model_architecture():
|
||||
"""Test that pi05=True creates the correct model architecture."""
|
||||
print("Testing PI0.5 model architecture...")
|
||||
|
||||
# Create config with pi05=True
|
||||
config = PI0OpenPIConfig(
|
||||
# Create config
|
||||
config = PI05OpenPIConfig(
|
||||
action_dim=7,
|
||||
state_dim=14,
|
||||
dtype="float32",
|
||||
pi05=True, # Enable PI0.5 mode
|
||||
)
|
||||
|
||||
# Verify tokenizer max length is set correctly
|
||||
@@ -25,7 +26,7 @@ def test_pi05_model_architecture():
|
||||
)
|
||||
print(f"✓ Tokenizer max length correctly set to {config.tokenizer_max_length}")
|
||||
|
||||
# Verify discrete_state_input defaults to pi05 value
|
||||
# Verify discrete_state_input defaults to pi05
|
||||
assert config.discrete_state_input == True, ( # noqa: E712
|
||||
f"Expected discrete_state_input=True for pi05, got {config.discrete_state_input}"
|
||||
)
|
||||
@@ -44,11 +45,9 @@ def test_pi05_model_architecture():
|
||||
}
|
||||
|
||||
# Instantiate policy
|
||||
policy = PI0OpenPIPolicy(config, dataset_stats)
|
||||
policy = PI05OpenPIPolicy(config, dataset_stats)
|
||||
|
||||
# Verify pi05 model components exist
|
||||
assert policy.model.pi05 == True, "Model pi05 flag not set" # noqa: E712
|
||||
print("✓ PI0.5 mode enabled in model")
|
||||
|
||||
# Check that time_mlp layers exist (for AdaRMS conditioning)
|
||||
assert hasattr(policy.model, "time_mlp_in"), "Missing time_mlp_in layer for pi05"
|
||||
@@ -80,15 +79,14 @@ def test_pi05_model_architecture():
|
||||
|
||||
|
||||
def test_pi05_forward_pass():
|
||||
"""Test forward pass with pi05=True."""
|
||||
"""Test forward pass with"""
|
||||
print("\nTesting PI0.5 forward pass...")
|
||||
|
||||
# Create config with pi05=True
|
||||
config = PI0OpenPIConfig(
|
||||
# Create config
|
||||
config = PI05OpenPIConfig(
|
||||
action_dim=7,
|
||||
state_dim=14,
|
||||
dtype="float32",
|
||||
pi05=True,
|
||||
action_horizon=16, # Shorter horizon for testing
|
||||
n_action_steps=16, # Shorter action steps for testing
|
||||
)
|
||||
@@ -106,7 +104,7 @@ def test_pi05_forward_pass():
|
||||
}
|
||||
|
||||
# Instantiate policy
|
||||
policy = PI0OpenPIPolicy(config, dataset_stats)
|
||||
policy = PI05OpenPIPolicy(config, dataset_stats)
|
||||
|
||||
# Create test batch
|
||||
batch_size = 2
|
||||
@@ -150,8 +148,8 @@ def test_pi0_vs_pi05_differences():
|
||||
print("\nComparing PI0 vs PI0.5 architectures...")
|
||||
|
||||
# Create both configurations
|
||||
config_pi0 = PI0OpenPIConfig(action_dim=7, state_dim=14, dtype="float32", pi05=False)
|
||||
config_pi05 = PI0OpenPIConfig(action_dim=7, state_dim=14, dtype="float32", pi05=True)
|
||||
config_pi0 = PI0OpenPIConfig(action_dim=7, state_dim=14, dtype="float32")
|
||||
config_pi05 = PI05OpenPIConfig(action_dim=7, state_dim=14, dtype="float32")
|
||||
|
||||
dataset_stats = {
|
||||
"observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
|
||||
@@ -160,7 +158,7 @@ def test_pi0_vs_pi05_differences():
|
||||
|
||||
# Create both models
|
||||
policy_pi0 = PI0OpenPIPolicy(config_pi0, dataset_stats)
|
||||
policy_pi05 = PI0OpenPIPolicy(config_pi05, dataset_stats)
|
||||
policy_pi05 = PI05OpenPIPolicy(config_pi05, dataset_stats)
|
||||
|
||||
print("\nPI0 Model:")
|
||||
print(f" - Tokenizer max length: {config_pi0.tokenizer_max_length}")
|
||||
|
||||
Reference in New Issue
Block a user