feat: only run pi test on GPU

This commit is contained in:
Pepijn
2025-09-17 15:55:58 +02:00
parent 7aebc526b2
commit 256b0e1e3c
4 changed files with 14 additions and 16 deletions
+4 -4
View File
@@ -11,10 +11,10 @@ pytest.importorskip("transformers")
from lerobot.policies.pi0_openpi.configuration_pi0openpi import PI0OpenPIConfig
from lerobot.policies.pi0_openpi.modeling_pi0openpi import PI0OpenPIPolicy
from lerobot.policies.pi05_openpi import PI05OpenPIConfig, PI05OpenPIPolicy
from tests.utils import require_nightly_gpu
from tests.utils import require_cuda
@require_nightly_gpu
@require_cuda
def test_pi05_model_architecture():
"""Test that pi05=True creates the correct model architecture."""
print("Testing PI0.5 model architecture...")
@@ -82,7 +82,7 @@ def test_pi05_model_architecture():
print("✓ AdaRMS correctly configured: PaliGemma=False, Expert=True")
@require_nightly_gpu
@require_cuda
def test_pi05_forward_pass():
"""Test forward pass with"""
print("\nTesting PI0.5 forward pass...")
@@ -146,7 +146,7 @@ def test_pi05_forward_pass():
raise
@require_nightly_gpu
@require_cuda
def test_pi0_vs_pi05_differences():
"""Test key differences between pi0 and pi05 modes."""
print("\nComparing PI0 vs PI0.5 architectures...")
@@ -16,7 +16,7 @@ from openpi.models_pytorch.pi0_pytorch import PI0Pytorch
from transformers import AutoTokenizer
from lerobot.policies.pi0_openpi import PI0OpenPIConfig, PI0OpenPIPolicy
from tests.utils import require_nightly_gpu
from tests.utils import require_cuda
DUMMY_ACTION_DIM = 32
DUMMY_STATE_DIM = 32
@@ -318,7 +318,7 @@ def create_original_observation_from_lerobot(lerobot_pi0, batch):
)
@require_nightly_gpu
@require_cuda
def test_pi0_original_vs_lerobot():
"""Test PI0 original implementation vs LeRobot implementation."""
print("Initializing models...")
+4 -4
View File
@@ -10,7 +10,7 @@ pytest.importorskip("transformers")
from lerobot.policies.pi0_openpi import PI0OpenPIPolicy
from lerobot.policies.pi05_openpi.modeling_pi05openpi import PI05OpenPIPolicy
from tests.utils import require_nightly_gpu
from tests.utils import require_cuda
def create_dummy_stats(config):
@@ -36,13 +36,13 @@ def create_dummy_stats(config):
return dummy_stats
@require_nightly_gpu
@require_cuda
def test_pi0_hub_loading():
"""Test loading PI0 model from HuggingFace hub."""
_test_hub_loading(model_id="pepijn223/pi0_base_fp32", model_name="PI0")
@require_nightly_gpu
@require_cuda
def test_pi05_hub_loading():
"""Test loading PI0.5 model from HuggingFace hub."""
_test_hub_loading(model_id="pepijn223/pi05_base_fp32", model_name="PI0.5")
@@ -253,7 +253,7 @@ MODEL_TEST_PARAMS = [
]
@require_nightly_gpu
@require_cuda
@pytest.mark.parametrize("model_id,model_type,policy_class", MODEL_TEST_PARAMS)
def test_all_base_models_hub_loading(model_id, model_type, policy_class):
"""Test loading and basic functionality of all 6 base models from HuggingFace Hub.
+4 -6
View File
@@ -169,17 +169,15 @@ def require_package_arg(func):
def require_nightly_gpu(func):
"""
Decorator that skips the test unless running in nightly environment with GPU.
Combines GPU availability check with nightly workflow detection.
Decorator that skips the test if GPU is not available.
Renamed from require_nightly_gpu to maintain backward compatibility,
but now only requires GPU availability (not nightly workflow).
"""
@require_cuda
@wraps(func)
def wrapper(*args, **kwargs):
# Check if running in nightly workflow (GitHub Actions)
is_nightly = os.environ.get("GITHUB_WORKFLOW") == "Nightly"
if not is_nightly:
pytest.skip("Test only runs in nightly workflow with GPU")
# Only check for GPU availability, no longer require nightly workflow
return func(*args, **kwargs)
return wrapper