upgrade test, fix failing

This commit is contained in:
Jade Choghari
2025-11-25 20:46:29 +01:00
parent f62cfc9ca2
commit 4e9acd4afe
@@ -18,7 +18,6 @@
# ruff: noqa: E402 # ruff: noqa: E402
import gc import gc
import os
import random import random
from copy import deepcopy from copy import deepcopy
from typing import Any from typing import Any
@@ -27,20 +26,12 @@ import numpy as np
import pytest import pytest
import torch import torch
# Skip if transformers is not available
pytest.importorskip("transformers")
# Skip this entire module in CI
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="This test requires XVLA model access and is not meant for CI",
)
from lerobot.policies.xvla.configuration_xvla import XVLAConfig from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.policies.xvla.modeling_xvla import XVLAPolicy from lerobot.policies.xvla.modeling_xvla import XVLAPolicy
from lerobot.policies.xvla.processor_xvla import make_xvla_pre_post_processors from lerobot.policies.xvla.processor_xvla import make_xvla_pre_post_processors
from lerobot.processor import PolicyAction, PolicyProcessorPipeline # noqa: E402 from lerobot.processor import PolicyAction, PolicyProcessorPipeline # noqa: E402
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE # noqa: E402 from lerobot.utils.constants import OBS_IMAGES, OBS_STATE # noqa: E402
from tests.utils import require_package # noqa: E402
# Constants # Constants
DUMMY_ACTION_DIM = 7 # Standard robot arm action dimension DUMMY_ACTION_DIM = 7 # Standard robot arm action dimension
@@ -171,6 +162,8 @@ def preprocessor(xvla_components):
return xvla_components[1] return xvla_components[1]
@require_package("transformers")
@require_package("timm")
def test_xvla_preprocessor_alignment(policy, preprocessor): def test_xvla_preprocessor_alignment(policy, preprocessor):
"""Test that LeRobot XVLA preprocessor produces expected outputs.""" """Test that LeRobot XVLA preprocessor produces expected outputs."""
print("\n" + "=" * 80) print("\n" + "=" * 80)
@@ -201,22 +194,24 @@ def test_xvla_preprocessor_alignment(policy, preprocessor):
for key, expected_shape in expected_shapes.items(): for key, expected_shape in expected_shapes.items():
if key in lerobot_inputs: if key in lerobot_inputs:
actual_shape = tuple(lerobot_inputs[key].shape) actual_shape = tuple(lerobot_inputs[key].shape)
print(f"\n🔎 Key: {key}") print(f"\nKey: {key}")
print(f" Expected shape: {expected_shape}") print(f"Expected shape: {expected_shape}")
print(f" Actual shape: {actual_shape}") print(f"Actual shape: {actual_shape}")
if actual_shape == expected_shape: if actual_shape == expected_shape:
print(" ✔️ Shape matches!") print("Shape matches!")
else: else:
print("Shape mismatch!") print("Shape mismatch!")
assert actual_shape == expected_shape, f"Shape mismatch for {key}" assert actual_shape == expected_shape, f"Shape mismatch for {key}"
else: else:
print(f"\n⚠️ Key '{key}' not found in inputs!") print(f"\nKey '{key}' not found in inputs!")
print("\nAll preprocessor outputs have correct shapes!") print("\nAll preprocessor outputs have correct shapes!")
@require_package("transformers")
@require_package("timm")
def test_xvla_action_generation(policy, preprocessor): def test_xvla_action_generation(policy, preprocessor):
"""Test XVLA LeRobot implementation generates expected actions.""" """Test XVLA LeRobot implementation generates expected actions."""
print("\n" + "=" * 80) print("\n" + "=" * 80)
@@ -283,17 +278,19 @@ def test_xvla_action_generation(policy, preprocessor):
tolerances = [1e-5, 1e-4, 1e-3, 1e-2] tolerances = [1e-5, 1e-4, 1e-3, 1e-2]
for tol in tolerances: for tol in tolerances:
is_close = torch.allclose(actual_first_5, EXPECTED_ACTIONS_FIRST_5, atol=tol) is_close = torch.allclose(actual_first_5, EXPECTED_ACTIONS_FIRST_5, atol=tol)
status = "✔️" if is_close else "" status = "Success" if is_close else "Failure"
print(f"{status} First 5 actions close (atol={tol}): {is_close}") print(f"{status}: First 5 actions close (atol={tol}): {is_close}")
# Assert with reasonable tolerance # Assert with reasonable tolerance
tolerance = 1e-3 tolerance = 1e-3
assert torch.allclose(actual_first_5, EXPECTED_ACTIONS_FIRST_5, atol=tolerance), ( assert torch.allclose(actual_first_5, EXPECTED_ACTIONS_FIRST_5, atol=tolerance), (
f"First 5 actions differ by more than tolerance ({tolerance})" f"First 5 actions differ by more than tolerance ({tolerance})"
) )
print(f"\nSuccess: Actions match expected values within tolerance ({tolerance})!") print(f"\nSuccess: Actions match expected values within tolerance ({tolerance})!")
@require_package("transformers")
@require_package("timm")
def test_xvla_inference_reproducibility(policy, preprocessor): def test_xvla_inference_reproducibility(policy, preprocessor):
"""Test that XVLA inference is reproducible with the same seed.""" """Test that XVLA inference is reproducible with the same seed."""
print("\n" + "=" * 80) print("\n" + "=" * 80)
@@ -324,16 +321,16 @@ def test_xvla_inference_reproducibility(policy, preprocessor):
print("\nComparing two runs:") print("\nComparing two runs:")
print("-" * 80) print("-" * 80)
if torch.allclose(actions_1, actions_2, atol=1e-8): if torch.allclose(actions_1, actions_2, atol=1e-8):
print("✔️ Inference is perfectly reproducible!") print("Inference is perfectly reproducible!")
else: else:
diff = torch.abs(actions_1 - actions_2) diff = torch.abs(actions_1 - actions_2)
print("⚠️ Small differences detected:") print("Small differences detected:")
print(f" Max diff: {diff.max().item():.6e}") print(f" Max diff: {diff.max().item():.6e}")
print(f" Mean diff: {diff.mean().item():.6e}") print(f" Mean diff: {diff.mean().item():.6e}")
assert torch.allclose(actions_1, actions_2, atol=1e-6), "Inference should be reproducible!" assert torch.allclose(actions_1, actions_2, atol=1e-6), "Inference should be reproducible!"
print("\nInference is reproducible!") print("\nInference is reproducible!")
if __name__ == "__main__": if __name__ == "__main__":
@@ -353,13 +350,13 @@ if __name__ == "__main__":
test_xvla_inference_reproducibility(policy, preprocessor) test_xvla_inference_reproducibility(policy, preprocessor)
print("\n" + "=" * 80) print("\n" + "=" * 80)
print("All tests passed!") print("All tests passed!")
print("=" * 80) print("=" * 80)
cleanup_memory() cleanup_memory()
except Exception as e: except Exception as e:
print("\n" + "=" * 80) print("\n" + "=" * 80)
print(f"Test failed with error: {e}") print(f"Test failed with error: {e}")
print("=" * 80) print("=" * 80)
cleanup_memory() cleanup_memory()
raise raise