mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 02:59:50 +00:00
feat: only run pi test on GPU
This commit is contained in:
@@ -11,10 +11,10 @@ pytest.importorskip("transformers")
|
||||
from lerobot.policies.pi0_openpi.configuration_pi0openpi import PI0OpenPIConfig
|
||||
from lerobot.policies.pi0_openpi.modeling_pi0openpi import PI0OpenPIPolicy
|
||||
from lerobot.policies.pi05_openpi import PI05OpenPIConfig, PI05OpenPIPolicy
|
||||
from tests.utils import require_nightly_gpu
|
||||
from tests.utils import require_cuda
|
||||
|
||||
|
||||
@require_nightly_gpu
|
||||
@require_cuda
|
||||
def test_pi05_model_architecture():
|
||||
"""Test that pi05=True creates the correct model architecture."""
|
||||
print("Testing PI0.5 model architecture...")
|
||||
@@ -82,7 +82,7 @@ def test_pi05_model_architecture():
|
||||
print("✓ AdaRMS correctly configured: PaliGemma=False, Expert=True")
|
||||
|
||||
|
||||
@require_nightly_gpu
|
||||
@require_cuda
|
||||
def test_pi05_forward_pass():
|
||||
"""Test forward pass with"""
|
||||
print("\nTesting PI0.5 forward pass...")
|
||||
@@ -146,7 +146,7 @@ def test_pi05_forward_pass():
|
||||
raise
|
||||
|
||||
|
||||
@require_nightly_gpu
|
||||
@require_cuda
|
||||
def test_pi0_vs_pi05_differences():
|
||||
"""Test key differences between pi0 and pi05 modes."""
|
||||
print("\nComparing PI0 vs PI0.5 architectures...")
|
||||
|
||||
@@ -16,7 +16,7 @@ from openpi.models_pytorch.pi0_pytorch import PI0Pytorch
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from lerobot.policies.pi0_openpi import PI0OpenPIConfig, PI0OpenPIPolicy
|
||||
from tests.utils import require_nightly_gpu
|
||||
from tests.utils import require_cuda
|
||||
|
||||
DUMMY_ACTION_DIM = 32
|
||||
DUMMY_STATE_DIM = 32
|
||||
@@ -318,7 +318,7 @@ def create_original_observation_from_lerobot(lerobot_pi0, batch):
|
||||
)
|
||||
|
||||
|
||||
@require_nightly_gpu
|
||||
@require_cuda
|
||||
def test_pi0_original_vs_lerobot():
|
||||
"""Test PI0 original implementation vs LeRobot implementation."""
|
||||
print("Initializing models...")
|
||||
|
||||
@@ -10,7 +10,7 @@ pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.policies.pi0_openpi import PI0OpenPIPolicy
|
||||
from lerobot.policies.pi05_openpi.modeling_pi05openpi import PI05OpenPIPolicy
|
||||
from tests.utils import require_nightly_gpu
|
||||
from tests.utils import require_cuda
|
||||
|
||||
|
||||
def create_dummy_stats(config):
|
||||
@@ -36,13 +36,13 @@ def create_dummy_stats(config):
|
||||
return dummy_stats
|
||||
|
||||
|
||||
@require_nightly_gpu
|
||||
@require_cuda
|
||||
def test_pi0_hub_loading():
|
||||
"""Test loading PI0 model from HuggingFace hub."""
|
||||
_test_hub_loading(model_id="pepijn223/pi0_base_fp32", model_name="PI0")
|
||||
|
||||
|
||||
@require_nightly_gpu
|
||||
@require_cuda
|
||||
def test_pi05_hub_loading():
|
||||
"""Test loading PI0.5 model from HuggingFace hub."""
|
||||
_test_hub_loading(model_id="pepijn223/pi05_base_fp32", model_name="PI0.5")
|
||||
@@ -253,7 +253,7 @@ MODEL_TEST_PARAMS = [
|
||||
]
|
||||
|
||||
|
||||
@require_nightly_gpu
|
||||
@require_cuda
|
||||
@pytest.mark.parametrize("model_id,model_type,policy_class", MODEL_TEST_PARAMS)
|
||||
def test_all_base_models_hub_loading(model_id, model_type, policy_class):
|
||||
"""Test loading and basic functionality of all 6 base models from HuggingFace Hub.
|
||||
|
||||
+4
-6
@@ -169,17 +169,15 @@ def require_package_arg(func):
|
||||
|
||||
def require_nightly_gpu(func):
|
||||
"""
|
||||
Decorator that skips the test unless running in nightly environment with GPU.
|
||||
Combines GPU availability check with nightly workflow detection.
|
||||
Decorator that skips the test if GPU is not available.
|
||||
Renamed from require_nightly_gpu to maintain backward compatibility,
|
||||
but now only requires GPU availability (not nightly workflow).
|
||||
"""
|
||||
|
||||
@require_cuda
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
# Check if running in nightly workflow (GitHub Actions)
|
||||
is_nightly = os.environ.get("GITHUB_WORKFLOW") == "Nightly"
|
||||
if not is_nightly:
|
||||
pytest.skip("Test only runs in nightly workflow with GPU")
|
||||
# Only check for GPU availability, no longer require nightly workflow
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
Reference in New Issue
Block a user