feat: only run pi test on GPU

This commit is contained in:
Pepijn
2025-09-17 15:55:58 +02:00
parent 7aebc526b2
commit 256b0e1e3c
4 changed files with 14 additions and 16 deletions
+4 -4
View File
@@ -10,7 +10,7 @@ pytest.importorskip("transformers")
from lerobot.policies.pi0_openpi import PI0OpenPIPolicy
from lerobot.policies.pi05_openpi.modeling_pi05openpi import PI05OpenPIPolicy
from tests.utils import require_nightly_gpu
from tests.utils import require_cuda
def create_dummy_stats(config):
@@ -36,13 +36,13 @@ def create_dummy_stats(config):
return dummy_stats
@require_nightly_gpu
@require_cuda
def test_pi0_hub_loading():
"""Test loading PI0 model from HuggingFace hub."""
_test_hub_loading(model_id="pepijn223/pi0_base_fp32", model_name="PI0")
@require_nightly_gpu
@require_cuda
def test_pi05_hub_loading():
"""Test loading PI0.5 model from HuggingFace hub."""
_test_hub_loading(model_id="pepijn223/pi05_base_fp32", model_name="PI0.5")
@@ -253,7 +253,7 @@ MODEL_TEST_PARAMS = [
]
@require_nightly_gpu
@require_cuda
@pytest.mark.parametrize("model_id,model_type,policy_class", MODEL_TEST_PARAMS)
def test_all_base_models_hub_loading(model_id, model_type, policy_class):
"""Test loading and basic functionality of all 6 base models from HuggingFace Hub.