mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-28 23:19:48 +00:00
cleanup tests
This commit is contained in:
@@ -1,16 +1,25 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# TODO(pepijn): Remove these tests before merging
|
||||
|
||||
"""Test script to load PI0OpenPI model from HuggingFace hub and run inference."""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
# Skip entire module if transformers is not available
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.policies.pi0_openpi import PI0OpenPIPolicy
|
||||
from lerobot.policies.pi05_openpi.modeling_pi05openpi import PI05OpenPIPolicy
|
||||
from tests.utils import require_cuda
|
||||
# Skip this entire module in CI
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="This test requires HuggingFace authentication and is not meant for CI",
|
||||
)
|
||||
|
||||
from lerobot.policies.pi0_openpi import PI0OpenPIPolicy # noqa: E402
|
||||
from lerobot.policies.pi05_openpi.modeling_pi05openpi import PI05OpenPIPolicy # noqa: E402
|
||||
|
||||
|
||||
def create_dummy_stats(config):
|
||||
@@ -36,13 +45,11 @@ def create_dummy_stats(config):
|
||||
return dummy_stats
|
||||
|
||||
|
||||
@require_cuda
|
||||
def test_pi0_hub_loading():
|
||||
"""Test loading PI0 model from HuggingFace hub."""
|
||||
_test_hub_loading(model_id="pepijn223/pi0_base_fp32", model_name="PI0")
|
||||
|
||||
|
||||
@require_cuda
|
||||
def test_pi05_hub_loading():
|
||||
"""Test loading PI0.5 model from HuggingFace hub."""
|
||||
_test_hub_loading(model_id="pepijn223/pi05_base_fp32", model_name="PI0.5")
|
||||
@@ -253,7 +260,6 @@ MODEL_TEST_PARAMS = [
|
||||
]
|
||||
|
||||
|
||||
@require_cuda
|
||||
@pytest.mark.parametrize("model_id,model_type,policy_class", MODEL_TEST_PARAMS)
|
||||
def test_all_base_models_hub_loading(model_id, model_type, policy_class):
|
||||
"""Test loading and basic functionality of all 6 base models from HuggingFace Hub.
|
||||
@@ -383,4 +389,4 @@ def test_all_base_models_hub_loading(model_id, model_type, policy_class):
|
||||
print(f"✗ Action prediction failed for {model_id}: {e}")
|
||||
raise
|
||||
|
||||
print(f"✅ All tests passed for {model_id}!")
|
||||
print(f"All tests passed for {model_id}!")
|
||||
|
||||
Reference in New Issue
Block a user