cleanup tests

This commit is contained in:
Pepijn
2025-09-17 17:35:07 +02:00
parent bc10fc7696
commit 64974c38c2
4 changed files with 37 additions and 97 deletions
+4 -61
View File
@@ -8,8 +8,6 @@ import torch
# Skip entire module if transformers is not available # Skip entire module if transformers is not available
pytest.importorskip("transformers") pytest.importorskip("transformers")
from lerobot.policies.pi0_openpi.configuration_pi0openpi import PI0OpenPIConfig
from lerobot.policies.pi0_openpi.modeling_pi0openpi import PI0OpenPIPolicy
from lerobot.policies.pi05_openpi import PI05OpenPIConfig, PI05OpenPIPolicy from lerobot.policies.pi05_openpi import PI05OpenPIConfig, PI05OpenPIPolicy
from tests.utils import require_cuda from tests.utils import require_cuda
@@ -17,7 +15,6 @@ from tests.utils import require_cuda
@require_cuda @require_cuda
def test_pi05_model_architecture(): def test_pi05_model_architecture():
"""Test that pi05=True creates the correct model architecture.""" """Test that pi05=True creates the correct model architecture."""
print("Testing PI0.5 model architecture...")
# Create config # Create config
config = PI05OpenPIConfig( config = PI05OpenPIConfig(
@@ -26,17 +23,12 @@ def test_pi05_model_architecture():
dtype="float32", dtype="float32",
) )
# Verify tokenizer max length is set correctly
assert config.tokenizer_max_length == 200, ( assert config.tokenizer_max_length == 200, (
f"Expected tokenizer_max_length=200 for pi05, got {config.tokenizer_max_length}" f"Expected tokenizer_max_length=200 for pi05, got {config.tokenizer_max_length}"
) )
print(f"✓ Tokenizer max length correctly set to {config.tokenizer_max_length}")
# Verify discrete_state_input defaults to pi05
assert config.discrete_state_input == True, ( # noqa: E712 assert config.discrete_state_input == True, ( # noqa: E712
f"Expected discrete_state_input=True for pi05, got {config.discrete_state_input}" f"Expected discrete_state_input=True for pi05, got {config.discrete_state_input}"
) )
print(f"✓ discrete_state_input correctly defaults to pi05 value: {config.discrete_state_input}")
# Create dummy dataset stats # Create dummy dataset stats
dataset_stats = { dataset_stats = {
@@ -54,22 +46,18 @@ def test_pi05_model_architecture():
policy = PI05OpenPIPolicy(config, dataset_stats) policy = PI05OpenPIPolicy(config, dataset_stats)
# Verify pi05 model components exist # Verify pi05 model components exist
# Check that time_mlp layers exist (for AdaRMS conditioning) # Check that time_mlp layers exist (for AdaRMS conditioning)
assert hasattr(policy.model, "time_mlp_in"), "Missing time_mlp_in layer for pi05" assert hasattr(policy.model, "time_mlp_in"), "Missing time_mlp_in layer for pi05"
assert hasattr(policy.model, "time_mlp_out"), "Missing time_mlp_out layer for pi05" assert hasattr(policy.model, "time_mlp_out"), "Missing time_mlp_out layer for pi05"
print("✓ Time MLP layers present for AdaRMS conditioning")
# Check that action_time_mlp layers don't exist (pi0 only) # Check that action_time_mlp layers don't exist (pi0 only)
assert not hasattr(policy.model, "action_time_mlp_in"), "action_time_mlp_in should not exist in pi05 mode" assert not hasattr(policy.model, "action_time_mlp_in"), "action_time_mlp_in should not exist in pi05 mode"
assert not hasattr(policy.model, "action_time_mlp_out"), ( assert not hasattr(policy.model, "action_time_mlp_out"), (
"action_time_mlp_out should not exist in pi05 mode" "action_time_mlp_out should not exist in pi05 mode"
) )
print("✓ Action-time MLP layers correctly absent")
# Check that state_proj doesn't exist in pi05 mode # Check that state_proj doesn't exist in pi05 mode
assert not hasattr(policy.model, "state_proj"), "state_proj should not exist in pi05 mode" assert not hasattr(policy.model, "state_proj"), "state_proj should not exist in pi05 mode"
print("✓ State projection layer correctly absent")
# Check AdaRMS configuration in the underlying model # Check AdaRMS configuration in the underlying model
adarms_config = policy.model.paligemma_with_expert.paligemma.config.text_config.use_adarms adarms_config = policy.model.paligemma_with_expert.paligemma.config.text_config.use_adarms
@@ -79,13 +67,11 @@ def test_pi05_model_architecture():
assert adarms_expert_config == True, ( # noqa: E712 assert adarms_expert_config == True, ( # noqa: E712
f"Action expert should use AdaRMS in pi05, got {adarms_expert_config}" f"Action expert should use AdaRMS in pi05, got {adarms_expert_config}"
) )
print("✓ AdaRMS correctly configured: PaliGemma=False, Expert=True")
@require_cuda @require_cuda
def test_pi05_forward_pass(): def test_pi05_forward_pass():
"""Test forward pass with""" """Test forward pass with"""
print("\nTesting PI0.5 forward pass...")
# Create config # Create config
config = PI05OpenPIConfig( config = PI05OpenPIConfig(
@@ -126,64 +112,21 @@ def test_pi05_forward_pass():
# Test forward pass # Test forward pass
try: try:
loss, loss_dict = policy.forward(batch) loss, loss_dict = policy.forward(batch)
print(f"Forward pass successful. Loss: {loss_dict['loss']:.4f}") print(f"Forward pass successful. Loss: {loss_dict['loss']:.4f}")
assert not torch.isnan(loss), "Loss is NaN" assert not torch.isnan(loss), "Loss is NaN"
assert loss.item() >= 0, "Loss should be non-negative" assert loss.item() >= 0, "Loss should be non-negative"
except Exception as e: except Exception as e:
print(f"Forward pass failed: {e}") print(f"Forward pass failed: {e}")
raise raise
# Test action prediction # Test action prediction
try: try:
with torch.no_grad(): with torch.no_grad():
action = policy.select_action(batch) action = policy.select_action(batch)
print(f"Action prediction successful. Action shape: {action.shape}") print(f"Action prediction successful. Action shape: {action.shape}")
# When batch_size > 1, select_action returns (batch_size, action_dim) # When batch_size > 1, select_action returns (batch_size, action_dim)
assert action.shape == (batch_size, 7), f"Expected action shape ({batch_size}, 7), got {action.shape}" assert action.shape == (batch_size, 7), f"Expected action shape ({batch_size}, 7), got {action.shape}"
assert not torch.isnan(action).any(), "Action contains NaN values" assert not torch.isnan(action).any(), "Action contains NaN values"
except Exception as e: except Exception as e:
print(f"Action prediction failed: {e}") print(f"Action prediction failed: {e}")
raise raise
@require_cuda
def test_pi0_vs_pi05_differences():
"""Test key differences between pi0 and pi05 modes."""
print("\nComparing PI0 vs PI0.5 architectures...")
# Create both configurations
config_pi0 = PI0OpenPIConfig(max_action_dim=7, max_state_dim=14, dtype="float32")
config_pi05 = PI05OpenPIConfig(max_action_dim=7, max_state_dim=14, dtype="float32")
dataset_stats = {
"observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
"action": {"mean": torch.zeros(7), "std": torch.ones(7)},
}
# Create both models
policy_pi0 = PI0OpenPIPolicy(config_pi0, dataset_stats)
policy_pi05 = PI05OpenPIPolicy(config_pi05, dataset_stats)
print("\nPI0 Model:")
print(f" - Tokenizer max length: {config_pi0.tokenizer_max_length}")
print(f" - Has state_proj: {hasattr(policy_pi0.model, 'state_proj')}")
print(f" - Has action_time_mlp: {hasattr(policy_pi0.model, 'action_time_mlp_in')}")
print(f" - Has time_mlp: {hasattr(policy_pi0.model, 'time_mlp_in')}")
print(f" - Uses AdaRMS: {policy_pi0.model.paligemma_with_expert.gemma_expert.config.use_adarms}")
print("\nPI0.5 Model:")
print(f" - Tokenizer max length: {config_pi05.tokenizer_max_length}")
print(f" - discrete_state_input: {config_pi05.discrete_state_input}")
print(f" - Has state_proj: {hasattr(policy_pi05.model, 'state_proj')}")
print(f" - Has action_time_mlp: {hasattr(policy_pi05.model, 'action_time_mlp_in')}")
print(f" - Has time_mlp: {hasattr(policy_pi05.model, 'time_mlp_in')}")
print(f" - Uses AdaRMS: {policy_pi05.model.paligemma_with_expert.gemma_expert.config.use_adarms}")
# Count parameters
pi0_params = sum(p.numel() for p in policy_pi0.parameters())
pi05_params = sum(p.numel() for p in policy_pi05.parameters())
print("\nParameter counts:")
print(f" - PI0: {pi0_params:,}")
print(f" - PI0.5: {pi05_params:,}")
print(f" - Difference: {pi0_params - pi05_params:,} (PI0.5 has fewer params due to no state embedding)")
+6 -14
View File
@@ -15,9 +15,6 @@ from tests.utils import require_cuda
@require_cuda @require_cuda
def test_policy_instantiation(): def test_policy_instantiation():
"""Test basic policy instantiation."""
print("Testing PI0OpenPI policy instantiation...")
# Create config # Create config
config = PI0OpenPIConfig(max_action_dim=7, max_state_dim=14, dtype="float32") config = PI0OpenPIConfig(max_action_dim=7, max_state_dim=14, dtype="float32")
@@ -35,7 +32,6 @@ def test_policy_instantiation():
# Instantiate policy # Instantiate policy
policy = PI0OpenPIPolicy(config, dataset_stats) policy = PI0OpenPIPolicy(config, dataset_stats)
print(f"Policy created successfully: {policy.name}")
# Test forward pass with dummy data # Test forward pass with dummy data
batch_size = 1 batch_size = 1
@@ -49,39 +45,35 @@ def test_policy_instantiation():
"task": ["Pick up the object"] * batch_size, "task": ["Pick up the object"] * batch_size,
} }
print("\nTesting forward pass...")
try: try:
loss, loss_dict = policy.forward(batch) loss, loss_dict = policy.forward(batch)
print(f"Forward pass successful. Loss: {loss_dict['loss']:.4f}") print(f"Forward pass successful. Loss: {loss_dict['loss']:.4f}")
except Exception as e: except Exception as e:
print(f"Forward pass failed: {e}") print(f"Forward pass failed: {e}")
raise raise
print("\nTesting action prediction...")
try: try:
with torch.no_grad(): with torch.no_grad():
action = policy.select_action(batch) action = policy.select_action(batch)
print(f"Action prediction successful. Action shape: {action.shape}") print(f"Action prediction successful. Action shape: {action.shape}")
except Exception as e: except Exception as e:
print(f"Action prediction failed: {e}") print(f"Action prediction failed: {e}")
raise raise
@require_cuda @require_cuda
def test_config_creation(): def test_config_creation():
"""Test policy config creation through factory.""" """Test policy config creation through factory."""
print("\nTesting config creation through factory...")
try: try:
config = make_policy_config( config = make_policy_config(
policy_type="pi0_openpi", policy_type="pi0_openpi",
max_action_dim=7, max_action_dim=7,
max_state_dim=14, max_state_dim=14,
) )
print("Config created successfully through factory") print("Config created successfully through factory")
print(f" Config type: {type(config).__name__}") print(f" Config type: {type(config).__name__}")
print(f" PaliGemma variant: {config.paligemma_variant}") print(f" PaliGemma variant: {config.paligemma_variant}")
print(f" Action expert variant: {config.action_expert_variant}") print(f" Action expert variant: {config.action_expert_variant}")
except Exception as e: except Exception as e:
print(f"Config creation failed: {e}") print(f"Config creation failed: {e}")
raise raise
+14 -15
View File
@@ -1,22 +1,27 @@
"""Test script to verify PI0OpenPI policy integration with LeRobot vs the original implementation.""" """Test script to verify PI0OpenPI policy integration with LeRobot vs the original implementation, only meant to be run locally!"""
import os import os
import pytest import pytest
import torch import torch
# Skip entire module if openpi or transformers is not available # Skip if openpi or transformers is not available
pytest.importorskip("openpi") pytest.importorskip("openpi")
pytest.importorskip("transformers") pytest.importorskip("transformers")
from openpi.models_pytorch import preprocessing_pytorch as openpi_preprocessing # Skip this entire module in CI
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="This test requires local OpenPI installation and is not meant for CI",
)
from openpi.models_pytorch import preprocessing_pytorch as openpi_preprocessing # noqa: E402
# NOTE: Assumes PYTHONPATH is set to include OpenPI src as per instructions. # NOTE: Assumes PYTHONPATH is set to include OpenPI src as per instructions.
from openpi.models_pytorch.pi0_pytorch import PI0Pytorch from openpi.models_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402
from transformers import AutoTokenizer from transformers import AutoTokenizer # noqa: E402
from lerobot.policies.pi0_openpi import PI0OpenPIConfig, PI0OpenPIPolicy from lerobot.policies.pi0_openpi import PI0OpenPIConfig, PI0OpenPIPolicy # noqa: E402
from tests.utils import require_cuda
DUMMY_ACTION_DIM = 32 DUMMY_ACTION_DIM = 32
DUMMY_STATE_DIM = 32 DUMMY_STATE_DIM = 32
@@ -320,7 +325,6 @@ def create_original_observation_from_lerobot(lerobot_pi0, batch):
) )
@require_cuda
def test_pi0_original_vs_lerobot(): def test_pi0_original_vs_lerobot():
"""Test PI0 original implementation vs LeRobot implementation.""" """Test PI0 original implementation vs LeRobot implementation."""
print("Initializing models...") print("Initializing models...")
@@ -333,7 +337,7 @@ def test_pi0_original_vs_lerobot():
batch = create_dummy_data() batch = create_dummy_data()
# Test 1: Each model with its own preprocessing (more realistic end-to-end test) # Test 1: Each model with its own preprocessing (more realistic end-to-end test)
print("\n=== TEST 1: Each model with its own preprocessing ===") print("\nTEST 1: Each model with its own preprocessing")
print("Creating observation for OpenPI using OpenPI's own preprocessing...") print("Creating observation for OpenPI using OpenPI's own preprocessing...")
pi0_obs_openpi = create_original_observation_with_openpi_preprocessing(batch) pi0_obs_openpi = create_original_observation_with_openpi_preprocessing(batch)
@@ -372,7 +376,7 @@ def test_pi0_original_vs_lerobot():
print(f"Max absolute difference: {torch.abs(lerobot_actions_own - openpi_actions).max().item():.6f}") print(f"Max absolute difference: {torch.abs(lerobot_actions_own - openpi_actions).max().item():.6f}")
# Test 2: Both models with LeRobot preprocessing (isolates model differences) # Test 2: Both models with LeRobot preprocessing (isolates model differences)
print("\n=== TEST 2: Both models with LeRobot preprocessing (model comparison) ===") print("\nTEST 2: Both models with LeRobot preprocessing (model comparison)")
print("Creating observation for OpenPI using LeRobot's preprocessing...") print("Creating observation for OpenPI using LeRobot's preprocessing...")
pi0_obs_lerobot = create_original_observation_from_lerobot(lerobot_pi0, batch) pi0_obs_lerobot = create_original_observation_from_lerobot(lerobot_pi0, batch)
@@ -397,8 +401,3 @@ def test_pi0_original_vs_lerobot():
# Add assertions for pytest # Add assertions for pytest
assert is_close_1e2, f"Models should produce similar results (atol=1e-2), max diff: {max_diff}" assert is_close_1e2, f"Models should produce similar results (atol=1e-2), max diff: {max_diff}"
print("\n=== SUMMARY ===")
print("Test 1 compares end-to-end pipelines (each model with its own preprocessing)")
print("Test 2 isolates model differences (both models with LeRobot preprocessing)")
print("Both tests completed successfully!")
+13 -7
View File
@@ -1,16 +1,25 @@
#!/usr/bin/env python #!/usr/bin/env python
# TODO(pepijn): Remove these tests before merging
"""Test script to load PI0OpenPI model from HuggingFace hub and run inference.""" """Test script to load PI0OpenPI model from HuggingFace hub and run inference."""
import os
import pytest import pytest
import torch import torch
# Skip entire module if transformers is not available # Skip entire module if transformers is not available
pytest.importorskip("transformers") pytest.importorskip("transformers")
from lerobot.policies.pi0_openpi import PI0OpenPIPolicy # Skip this entire module in CI
from lerobot.policies.pi05_openpi.modeling_pi05openpi import PI05OpenPIPolicy pytestmark = pytest.mark.skipif(
from tests.utils import require_cuda os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="This test requires HuggingFace authentication and is not meant for CI",
)
from lerobot.policies.pi0_openpi import PI0OpenPIPolicy # noqa: E402
from lerobot.policies.pi05_openpi.modeling_pi05openpi import PI05OpenPIPolicy # noqa: E402
def create_dummy_stats(config): def create_dummy_stats(config):
@@ -36,13 +45,11 @@ def create_dummy_stats(config):
return dummy_stats return dummy_stats
@require_cuda
def test_pi0_hub_loading(): def test_pi0_hub_loading():
"""Test loading PI0 model from HuggingFace hub.""" """Test loading PI0 model from HuggingFace hub."""
_test_hub_loading(model_id="pepijn223/pi0_base_fp32", model_name="PI0") _test_hub_loading(model_id="pepijn223/pi0_base_fp32", model_name="PI0")
@require_cuda
def test_pi05_hub_loading(): def test_pi05_hub_loading():
"""Test loading PI0.5 model from HuggingFace hub.""" """Test loading PI0.5 model from HuggingFace hub."""
_test_hub_loading(model_id="pepijn223/pi05_base_fp32", model_name="PI0.5") _test_hub_loading(model_id="pepijn223/pi05_base_fp32", model_name="PI0.5")
@@ -253,7 +260,6 @@ MODEL_TEST_PARAMS = [
] ]
@require_cuda
@pytest.mark.parametrize("model_id,model_type,policy_class", MODEL_TEST_PARAMS) @pytest.mark.parametrize("model_id,model_type,policy_class", MODEL_TEST_PARAMS)
def test_all_base_models_hub_loading(model_id, model_type, policy_class): def test_all_base_models_hub_loading(model_id, model_type, policy_class):
"""Test loading and basic functionality of all 6 base models from HuggingFace Hub. """Test loading and basic functionality of all 6 base models from HuggingFace Hub.
@@ -383,4 +389,4 @@ def test_all_base_models_hub_loading(model_id, model_type, policy_class):
print(f"✗ Action prediction failed for {model_id}: {e}") print(f"✗ Action prediction failed for {model_id}: {e}")
raise raise
print(f"All tests passed for {model_id}!") print(f"All tests passed for {model_id}!")