diff --git a/tests/policies/xvla/test_xvla_original_vs_lerobot.py b/tests/policies/xvla/test_xvla_original_vs_lerobot.py index 9baaa85cd..ac37af146 100644 --- a/tests/policies/xvla/test_xvla_original_vs_lerobot.py +++ b/tests/policies/xvla/test_xvla_original_vs_lerobot.py @@ -18,7 +18,6 @@ # ruff: noqa: E402 import gc -import os import random from copy import deepcopy from typing import Any @@ -27,20 +26,12 @@ import numpy as np import pytest import torch -# Skip if transformers is not available -pytest.importorskip("transformers") - -# Skip this entire module in CI -pytestmark = pytest.mark.skipif( - os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true", - reason="This test requires XVLA model access and is not meant for CI", -) - from lerobot.policies.xvla.configuration_xvla import XVLAConfig from lerobot.policies.xvla.modeling_xvla import XVLAPolicy from lerobot.policies.xvla.processor_xvla import make_xvla_pre_post_processors from lerobot.processor import PolicyAction, PolicyProcessorPipeline # noqa: E402 from lerobot.utils.constants import OBS_IMAGES, OBS_STATE # noqa: E402 +from tests.utils import require_package # noqa: E402 # Constants DUMMY_ACTION_DIM = 7 # Standard robot arm action dimension @@ -171,6 +162,8 @@ def preprocessor(xvla_components): return xvla_components[1] +@require_package("transformers") +@require_package("timm") def test_xvla_preprocessor_alignment(policy, preprocessor): """Test that LeRobot XVLA preprocessor produces expected outputs.""" print("\n" + "=" * 80) @@ -201,22 +194,24 @@ def test_xvla_preprocessor_alignment(policy, preprocessor): for key, expected_shape in expected_shapes.items(): if key in lerobot_inputs: actual_shape = tuple(lerobot_inputs[key].shape) - print(f"\nšŸ”Ž Key: {key}") - print(f" Expected shape: {expected_shape}") - print(f" Actual shape: {actual_shape}") + print(f"\nKey: {key}") + print(f"Expected shape: {expected_shape}") + print(f"Actual shape: {actual_shape}") if actual_shape == expected_shape: - print(" āœ”ļø Shape matches!") + print("Shape matches!") else: - print(" āŒ Shape mismatch!") + print("Shape mismatch!") assert actual_shape == expected_shape, f"Shape mismatch for {key}" else: - print(f"\nāš ļø Key '{key}' not found in inputs!") + print(f"\nKey '{key}' not found in inputs!") - print("\nāœ… All preprocessor outputs have correct shapes!") + print("\nAll preprocessor outputs have correct shapes!") +@require_package("transformers") +@require_package("timm") def test_xvla_action_generation(policy, preprocessor): """Test XVLA LeRobot implementation generates expected actions.""" print("\n" + "=" * 80) @@ -283,17 +278,19 @@ def test_xvla_action_generation(policy, preprocessor): tolerances = [1e-5, 1e-4, 1e-3, 1e-2] for tol in tolerances: is_close = torch.allclose(actual_first_5, EXPECTED_ACTIONS_FIRST_5, atol=tol) - status = "āœ”ļø" if is_close else "āŒ" - print(f"{status} First 5 actions close (atol={tol}): {is_close}") + status = "Success" if is_close else "Failure" + print(f"{status}: First 5 actions close (atol={tol}): {is_close}") # Assert with reasonable tolerance tolerance = 1e-3 assert torch.allclose(actual_first_5, EXPECTED_ACTIONS_FIRST_5, atol=tolerance), ( f"First 5 actions differ by more than tolerance ({tolerance})" ) - print(f"\nāœ… Success: Actions match expected values within tolerance ({tolerance})!") + print(f"\nSuccess: Actions match expected values within tolerance ({tolerance})!") +@require_package("transformers") +@require_package("timm") def test_xvla_inference_reproducibility(policy, preprocessor): """Test that XVLA inference is reproducible with the same seed.""" print("\n" + "=" * 80) @@ -324,16 +321,16 @@ def test_xvla_inference_reproducibility(policy, preprocessor): print("\nComparing two runs:") print("-" * 80) if torch.allclose(actions_1, actions_2, atol=1e-8): - print("āœ”ļø Inference is perfectly reproducible!") + print("Inference is perfectly reproducible!") else: diff = torch.abs(actions_1 - actions_2) - print("āš ļø Small differences detected:") + print("Small differences detected:") print(f" Max diff: {diff.max().item():.6e}") print(f" Mean diff: {diff.mean().item():.6e}") assert torch.allclose(actions_1, actions_2, atol=1e-6), "Inference should be reproducible!" - print("\nāœ… Inference is reproducible!") + print("\nInference is reproducible!") if __name__ == "__main__": @@ -353,13 +350,13 @@ if __name__ == "__main__": test_xvla_inference_reproducibility(policy, preprocessor) print("\n" + "=" * 80) - print("āœ… All tests passed!") + print("All tests passed!") print("=" * 80) cleanup_memory() except Exception as e: print("\n" + "=" * 80) - print(f"āŒ Test failed with error: {e}") + print(f"Test failed with error: {e}") print("=" * 80) cleanup_memory() raise