more fixes to testing

This commit is contained in:
Jade Choghari
2025-11-25 21:29:52 +01:00
parent 15dc2fd867
commit 81cf4d8ed5
2 changed files with 5 additions and 19 deletions
@@ -26,21 +26,14 @@ import numpy as np
import pytest
import torch
# Conditional import for type checking and lazy loading
from lerobot.utils.import_utils import _timm_available, _transformers_available
if _timm_available and _transformers_available:
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.policies.xvla.modeling_xvla import XVLAPolicy
from lerobot.policies.xvla.processor_xvla import make_xvla_pre_post_processors
else:
XVLAConfig = None
XVLAPolicy = None
make_xvla_pre_post_processors = None
pytest.importorskip("timm")
pytest.importorskip("transformers")
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.policies.xvla.modeling_xvla import XVLAPolicy
from lerobot.policies.xvla.processor_xvla import make_xvla_pre_post_processors
from lerobot.processor import PolicyAction, PolicyProcessorPipeline # noqa: E402
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE # noqa: E402
from tests.utils import require_package # noqa: E402
# Constants
DUMMY_ACTION_DIM = 7 # Standard robot arm action dimension
@@ -171,8 +164,6 @@ def preprocessor(xvla_components):
return xvla_components[1]
@require_package("transformers")
@require_package("timm")
def test_xvla_preprocessor_alignment(policy, preprocessor):
"""Test that LeRobot XVLA preprocessor produces expected outputs."""
print("\n" + "=" * 80)
@@ -219,8 +210,6 @@ def test_xvla_preprocessor_alignment(policy, preprocessor):
print("\nAll preprocessor outputs have correct shapes!")
@require_package("transformers")
@require_package("timm")
def test_xvla_action_generation(policy, preprocessor):
"""Test XVLA LeRobot implementation generates expected actions."""
print("\n" + "=" * 80)
@@ -298,8 +287,6 @@ def test_xvla_action_generation(policy, preprocessor):
print(f"\nSuccess: Actions match expected values within tolerance ({tolerance})!")
@require_package("transformers")
@require_package("timm")
def test_xvla_inference_reproducibility(policy, preprocessor):
"""Test that XVLA inference is reproducible with the same seed."""
print("\n" + "=" * 80)