mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 11:09:59 +00:00
feat(dependecies): minimal default tag install
This commit is contained in:
@@ -21,7 +21,7 @@ from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import ClassifierOutput
|
||||
from lerobot.utils.constants import OBS_IMAGE, REWARD
|
||||
from tests.utils import require_package
|
||||
from tests.utils import skip_if_package_missing
|
||||
|
||||
|
||||
def test_classifier_output():
|
||||
@@ -37,7 +37,7 @@ def test_classifier_output():
|
||||
)
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@skip_if_package_missing("transformers")
|
||||
@pytest.mark.skip(
|
||||
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
|
||||
)
|
||||
@@ -81,7 +81,7 @@ def test_binary_classifier_with_default_params():
|
||||
assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values"
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@skip_if_package_missing("transformers")
|
||||
@pytest.mark.skip(
|
||||
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
|
||||
)
|
||||
@@ -123,7 +123,7 @@ def test_multiclass_classifier():
|
||||
assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values"
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@skip_if_package_missing("transformers")
|
||||
@pytest.mark.skip(
|
||||
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
|
||||
)
|
||||
@@ -138,7 +138,7 @@ def test_default_device():
|
||||
assert p.device == torch.device("cpu")
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@skip_if_package_missing("transformers")
|
||||
@pytest.mark.skip(
|
||||
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
|
||||
)
|
||||
|
||||
@@ -24,10 +24,10 @@ from lerobot.policies.factory import make_pre_post_processors # noqa: E402
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig # noqa: E402
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig # noqa: F401
|
||||
from lerobot.utils.random_utils import set_seed # noqa: E402
|
||||
from tests.utils import require_cuda, require_package # noqa: E402
|
||||
from tests.utils import require_cuda, skip_if_package_missing # noqa: E402
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@skip_if_package_missing("transformers")
|
||||
@require_cuda
|
||||
def test_smolvla_rtc_initialization():
|
||||
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy # noqa: F401
|
||||
@@ -65,7 +65,7 @@ def test_smolvla_rtc_initialization():
|
||||
print("✓ SmolVLA RTC initialization: Test passed")
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@skip_if_package_missing("transformers")
|
||||
@require_cuda
|
||||
def test_smolvla_rtc_initialization_without_rtc_config():
|
||||
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy # noqa: F401
|
||||
@@ -87,7 +87,7 @@ def test_smolvla_rtc_initialization_without_rtc_config():
|
||||
print("✓ SmolVLA RTC initialization without RTC config: Test passed")
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@skip_if_package_missing("transformers")
|
||||
@require_cuda
|
||||
@pytest.mark.skipif(True, reason="Requires pretrained SmolVLA model weights")
|
||||
def test_smolvla_rtc_inference_with_prev_chunk():
|
||||
@@ -170,7 +170,7 @@ def test_smolvla_rtc_inference_with_prev_chunk():
|
||||
print("✓ SmolVLA RTC inference with prev_chunk: Test passed")
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@skip_if_package_missing("transformers")
|
||||
@require_cuda
|
||||
@pytest.mark.skipif(True, reason="Requires pretrained SmolVLA model weights")
|
||||
def test_smolvla_rtc_inference_without_prev_chunk():
|
||||
@@ -244,7 +244,7 @@ def test_smolvla_rtc_inference_without_prev_chunk():
|
||||
print("✓ SmolVLA RTC inference without prev_chunk: Test passed")
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@skip_if_package_missing("transformers")
|
||||
@require_cuda
|
||||
@pytest.mark.skipif(True, reason="Requires pretrained SmolVLA model weights")
|
||||
def test_smolvla_rtc_validation_rules():
|
||||
|
||||
Reference in New Issue
Block a user