Fix tests

This commit is contained in:
Eugene Mironov
2025-11-19 03:10:27 +07:00
parent 8008dbb02c
commit 59a52e557c
+16 -2
View File
@@ -23,13 +23,15 @@ from lerobot.configs.types import FeatureType, PolicyFeature, RTCAttentionSchedu
from lerobot.policies.factory import make_pre_post_processors # noqa: E402 from lerobot.policies.factory import make_pre_post_processors # noqa: E402
from lerobot.policies.rtc.configuration_rtc import RTCConfig # noqa: E402 from lerobot.policies.rtc.configuration_rtc import RTCConfig # noqa: E402
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig # noqa: F401 from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig # noqa: F401
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy # noqa: F401
from lerobot.utils.random_utils import set_seed # noqa: E402 from lerobot.utils.random_utils import set_seed # noqa: E402
from tests.utils import require_cuda # noqa: E402 from tests.utils import require_cuda, require_package # noqa: E402
@require_package("transformers")
@require_cuda @require_cuda
def test_smolvla_rtc_initialization(): def test_smolvla_rtc_initialization():
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy # noqa: F401
"""Test SmolVLA policy can initialize RTC processor.""" """Test SmolVLA policy can initialize RTC processor."""
set_seed(42) set_seed(42)
@@ -63,8 +65,11 @@ def test_smolvla_rtc_initialization():
print("✓ SmolVLA RTC initialization: Test passed") print("✓ SmolVLA RTC initialization: Test passed")
@require_package("transformers")
@require_cuda @require_cuda
def test_smolvla_rtc_initialization_without_rtc_config(): def test_smolvla_rtc_initialization_without_rtc_config():
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy # noqa: F401
"""Test SmolVLA policy can initialize without RTC config.""" """Test SmolVLA policy can initialize without RTC config."""
set_seed(42) set_seed(42)
@@ -82,9 +87,12 @@ def test_smolvla_rtc_initialization_without_rtc_config():
print("✓ SmolVLA RTC initialization without RTC config: Test passed") print("✓ SmolVLA RTC initialization without RTC config: Test passed")
@require_package("transformers")
@require_cuda @require_cuda
@pytest.mark.skipif(True, reason="Requires pretrained SmolVLA model weights") @pytest.mark.skipif(True, reason="Requires pretrained SmolVLA model weights")
def test_smolvla_rtc_inference_with_prev_chunk(): def test_smolvla_rtc_inference_with_prev_chunk():
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy # noqa: F401
"""Test SmolVLA policy inference with RTC and previous chunk.""" """Test SmolVLA policy inference with RTC and previous chunk."""
set_seed(42) set_seed(42)
@@ -162,9 +170,12 @@ def test_smolvla_rtc_inference_with_prev_chunk():
print("✓ SmolVLA RTC inference with prev_chunk: Test passed") print("✓ SmolVLA RTC inference with prev_chunk: Test passed")
@require_package("transformers")
@require_cuda @require_cuda
@pytest.mark.skipif(True, reason="Requires pretrained SmolVLA model weights") @pytest.mark.skipif(True, reason="Requires pretrained SmolVLA model weights")
def test_smolvla_rtc_inference_without_prev_chunk(): def test_smolvla_rtc_inference_without_prev_chunk():
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy # noqa: F401
"""Test SmolVLA policy inference with RTC but no previous chunk (RTC should have no effect).""" """Test SmolVLA policy inference with RTC but no previous chunk (RTC should have no effect)."""
set_seed(42) set_seed(42)
@@ -233,9 +244,12 @@ def test_smolvla_rtc_inference_without_prev_chunk():
print("✓ SmolVLA RTC inference without prev_chunk: Test passed") print("✓ SmolVLA RTC inference without prev_chunk: Test passed")
@require_package("transformers")
@require_cuda @require_cuda
@pytest.mark.skipif(True, reason="Requires pretrained SmolVLA model weights") @pytest.mark.skipif(True, reason="Requires pretrained SmolVLA model weights")
def test_smolvla_rtc_validation_rules(): def test_smolvla_rtc_validation_rules():
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy # noqa: F401
"""Test SmolVLA policy with RTC follows all three validation rules.""" """Test SmolVLA policy with RTC follows all three validation rules."""
set_seed(42) set_seed(42)