mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 20:19:43 +00:00
Fix tests
This commit is contained in:
@@ -23,13 +23,15 @@ from lerobot.configs.types import FeatureType, PolicyFeature, RTCAttentionSchedu
|
|||||||
from lerobot.policies.factory import make_pre_post_processors # noqa: E402
|
from lerobot.policies.factory import make_pre_post_processors # noqa: E402
|
||||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig # noqa: E402
|
from lerobot.policies.rtc.configuration_rtc import RTCConfig # noqa: E402
|
||||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig # noqa: F401
|
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig # noqa: F401
|
||||||
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy # noqa: F401
|
|
||||||
from lerobot.utils.random_utils import set_seed # noqa: E402
|
from lerobot.utils.random_utils import set_seed # noqa: E402
|
||||||
from tests.utils import require_cuda # noqa: E402
|
from tests.utils import require_cuda, require_package # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
@require_package("transformers")
|
||||||
@require_cuda
|
@require_cuda
|
||||||
def test_smolvla_rtc_initialization():
|
def test_smolvla_rtc_initialization():
|
||||||
|
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy # noqa: F401
|
||||||
|
|
||||||
"""Test SmolVLA policy can initialize RTC processor."""
|
"""Test SmolVLA policy can initialize RTC processor."""
|
||||||
set_seed(42)
|
set_seed(42)
|
||||||
|
|
||||||
@@ -63,8 +65,11 @@ def test_smolvla_rtc_initialization():
|
|||||||
print("✓ SmolVLA RTC initialization: Test passed")
|
print("✓ SmolVLA RTC initialization: Test passed")
|
||||||
|
|
||||||
|
|
||||||
|
@require_package("transformers")
|
||||||
@require_cuda
|
@require_cuda
|
||||||
def test_smolvla_rtc_initialization_without_rtc_config():
|
def test_smolvla_rtc_initialization_without_rtc_config():
|
||||||
|
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy # noqa: F401
|
||||||
|
|
||||||
"""Test SmolVLA policy can initialize without RTC config."""
|
"""Test SmolVLA policy can initialize without RTC config."""
|
||||||
set_seed(42)
|
set_seed(42)
|
||||||
|
|
||||||
@@ -82,9 +87,12 @@ def test_smolvla_rtc_initialization_without_rtc_config():
|
|||||||
print("✓ SmolVLA RTC initialization without RTC config: Test passed")
|
print("✓ SmolVLA RTC initialization without RTC config: Test passed")
|
||||||
|
|
||||||
|
|
||||||
|
@require_package("transformers")
|
||||||
@require_cuda
|
@require_cuda
|
||||||
@pytest.mark.skipif(True, reason="Requires pretrained SmolVLA model weights")
|
@pytest.mark.skipif(True, reason="Requires pretrained SmolVLA model weights")
|
||||||
def test_smolvla_rtc_inference_with_prev_chunk():
|
def test_smolvla_rtc_inference_with_prev_chunk():
|
||||||
|
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy # noqa: F401
|
||||||
|
|
||||||
"""Test SmolVLA policy inference with RTC and previous chunk."""
|
"""Test SmolVLA policy inference with RTC and previous chunk."""
|
||||||
set_seed(42)
|
set_seed(42)
|
||||||
|
|
||||||
@@ -162,9 +170,12 @@ def test_smolvla_rtc_inference_with_prev_chunk():
|
|||||||
print("✓ SmolVLA RTC inference with prev_chunk: Test passed")
|
print("✓ SmolVLA RTC inference with prev_chunk: Test passed")
|
||||||
|
|
||||||
|
|
||||||
|
@require_package("transformers")
|
||||||
@require_cuda
|
@require_cuda
|
||||||
@pytest.mark.skipif(True, reason="Requires pretrained SmolVLA model weights")
|
@pytest.mark.skipif(True, reason="Requires pretrained SmolVLA model weights")
|
||||||
def test_smolvla_rtc_inference_without_prev_chunk():
|
def test_smolvla_rtc_inference_without_prev_chunk():
|
||||||
|
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy # noqa: F401
|
||||||
|
|
||||||
"""Test SmolVLA policy inference with RTC but no previous chunk (RTC should have no effect)."""
|
"""Test SmolVLA policy inference with RTC but no previous chunk (RTC should have no effect)."""
|
||||||
set_seed(42)
|
set_seed(42)
|
||||||
|
|
||||||
@@ -233,9 +244,12 @@ def test_smolvla_rtc_inference_without_prev_chunk():
|
|||||||
print("✓ SmolVLA RTC inference without prev_chunk: Test passed")
|
print("✓ SmolVLA RTC inference without prev_chunk: Test passed")
|
||||||
|
|
||||||
|
|
||||||
|
@require_package("transformers")
|
||||||
@require_cuda
|
@require_cuda
|
||||||
@pytest.mark.skipif(True, reason="Requires pretrained SmolVLA model weights")
|
@pytest.mark.skipif(True, reason="Requires pretrained SmolVLA model weights")
|
||||||
def test_smolvla_rtc_validation_rules():
|
def test_smolvla_rtc_validation_rules():
|
||||||
|
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy # noqa: F401
|
||||||
|
|
||||||
"""Test SmolVLA policy with RTC follows all three validation rules."""
|
"""Test SmolVLA policy with RTC follows all three validation rules."""
|
||||||
set_seed(42)
|
set_seed(42)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user