upgrade test, fix failing

This commit is contained in:
Jade Choghari
2025-11-25 20:46:29 +01:00
parent f62cfc9ca2
commit 4e9acd4afe
@@ -18,7 +18,6 @@
# ruff: noqa: E402
import gc
import os
import random
from copy import deepcopy
from typing import Any
@@ -27,20 +26,12 @@ import numpy as np
import pytest
import torch
# Skip if transformers is not available
pytest.importorskip("transformers")
# Skip this entire module in CI
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="This test requires XVLA model access and is not meant for CI",
)
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.policies.xvla.modeling_xvla import XVLAPolicy
from lerobot.policies.xvla.processor_xvla import make_xvla_pre_post_processors
from lerobot.processor import PolicyAction, PolicyProcessorPipeline # noqa: E402
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE # noqa: E402
from tests.utils import require_package # noqa: E402
# Constants
DUMMY_ACTION_DIM = 7 # Standard robot arm action dimension
@@ -171,6 +162,8 @@ def preprocessor(xvla_components):
return xvla_components[1]
@require_package("transformers")
@require_package("timm")
def test_xvla_preprocessor_alignment(policy, preprocessor):
"""Test that LeRobot XVLA preprocessor produces expected outputs."""
print("\n" + "=" * 80)
@@ -201,22 +194,24 @@ def test_xvla_preprocessor_alignment(policy, preprocessor):
for key, expected_shape in expected_shapes.items():
if key in lerobot_inputs:
actual_shape = tuple(lerobot_inputs[key].shape)
print(f"\n🔎 Key: {key}")
print(f" Expected shape: {expected_shape}")
print(f" Actual shape: {actual_shape}")
print(f"\nKey: {key}")
print(f"Expected shape: {expected_shape}")
print(f"Actual shape: {actual_shape}")
if actual_shape == expected_shape:
print(" ✔️ Shape matches!")
print("Shape matches!")
else:
print("Shape mismatch!")
print("Shape mismatch!")
assert actual_shape == expected_shape, f"Shape mismatch for {key}"
else:
print(f"\n⚠️ Key '{key}' not found in inputs!")
print(f"\nKey '{key}' not found in inputs!")
print("\nAll preprocessor outputs have correct shapes!")
print("\nAll preprocessor outputs have correct shapes!")
@require_package("transformers")
@require_package("timm")
def test_xvla_action_generation(policy, preprocessor):
"""Test XVLA LeRobot implementation generates expected actions."""
print("\n" + "=" * 80)
@@ -283,17 +278,19 @@ def test_xvla_action_generation(policy, preprocessor):
tolerances = [1e-5, 1e-4, 1e-3, 1e-2]
for tol in tolerances:
is_close = torch.allclose(actual_first_5, EXPECTED_ACTIONS_FIRST_5, atol=tol)
status = "✔️" if is_close else ""
print(f"{status} First 5 actions close (atol={tol}): {is_close}")
status = "Success" if is_close else "Failure"
print(f"{status}: First 5 actions close (atol={tol}): {is_close}")
# Assert with reasonable tolerance
tolerance = 1e-3
assert torch.allclose(actual_first_5, EXPECTED_ACTIONS_FIRST_5, atol=tolerance), (
f"First 5 actions differ by more than tolerance ({tolerance})"
)
print(f"\nSuccess: Actions match expected values within tolerance ({tolerance})!")
print(f"\nSuccess: Actions match expected values within tolerance ({tolerance})!")
@require_package("transformers")
@require_package("timm")
def test_xvla_inference_reproducibility(policy, preprocessor):
"""Test that XVLA inference is reproducible with the same seed."""
print("\n" + "=" * 80)
@@ -324,16 +321,16 @@ def test_xvla_inference_reproducibility(policy, preprocessor):
print("\nComparing two runs:")
print("-" * 80)
if torch.allclose(actions_1, actions_2, atol=1e-8):
print("✔️ Inference is perfectly reproducible!")
print("Inference is perfectly reproducible!")
else:
diff = torch.abs(actions_1 - actions_2)
print("⚠️ Small differences detected:")
print("Small differences detected:")
print(f" Max diff: {diff.max().item():.6e}")
print(f" Mean diff: {diff.mean().item():.6e}")
assert torch.allclose(actions_1, actions_2, atol=1e-6), "Inference should be reproducible!"
print("\nInference is reproducible!")
print("\nInference is reproducible!")
if __name__ == "__main__":
@@ -353,13 +350,13 @@ if __name__ == "__main__":
test_xvla_inference_reproducibility(policy, preprocessor)
print("\n" + "=" * 80)
print("All tests passed!")
print("All tests passed!")
print("=" * 80)
cleanup_memory()
except Exception as e:
print("\n" + "=" * 80)
print(f"Test failed with error: {e}")
print(f"Test failed with error: {e}")
print("=" * 80)
cleanup_memory()
raise