mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-17 17:50:09 +00:00
upgrade test, fix failing
This commit is contained in:
@@ -18,7 +18,6 @@
|
||||
# ruff: noqa: E402
|
||||
|
||||
import gc
|
||||
import os
|
||||
import random
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
@@ -27,20 +26,12 @@ import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
# Skip if transformers is not available
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
# Skip this entire module in CI
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="This test requires XVLA model access and is not meant for CI",
|
||||
)
|
||||
|
||||
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
||||
from lerobot.policies.xvla.modeling_xvla import XVLAPolicy
|
||||
from lerobot.policies.xvla.processor_xvla import make_xvla_pre_post_processors
|
||||
from lerobot.processor import PolicyAction, PolicyProcessorPipeline # noqa: E402
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE # noqa: E402
|
||||
from tests.utils import require_package # noqa: E402
|
||||
|
||||
# Constants
|
||||
DUMMY_ACTION_DIM = 7 # Standard robot arm action dimension
|
||||
@@ -171,6 +162,8 @@ def preprocessor(xvla_components):
|
||||
return xvla_components[1]
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@require_package("timm")
|
||||
def test_xvla_preprocessor_alignment(policy, preprocessor):
|
||||
"""Test that LeRobot XVLA preprocessor produces expected outputs."""
|
||||
print("\n" + "=" * 80)
|
||||
@@ -201,22 +194,24 @@ def test_xvla_preprocessor_alignment(policy, preprocessor):
|
||||
for key, expected_shape in expected_shapes.items():
|
||||
if key in lerobot_inputs:
|
||||
actual_shape = tuple(lerobot_inputs[key].shape)
|
||||
print(f"\n🔎 Key: {key}")
|
||||
print(f" Expected shape: {expected_shape}")
|
||||
print(f" Actual shape: {actual_shape}")
|
||||
print(f"\nKey: {key}")
|
||||
print(f"Expected shape: {expected_shape}")
|
||||
print(f"Actual shape: {actual_shape}")
|
||||
|
||||
if actual_shape == expected_shape:
|
||||
print(" ✔️ Shape matches!")
|
||||
print("Shape matches!")
|
||||
else:
|
||||
print(" ❌ Shape mismatch!")
|
||||
print("Shape mismatch!")
|
||||
|
||||
assert actual_shape == expected_shape, f"Shape mismatch for {key}"
|
||||
else:
|
||||
print(f"\n⚠️ Key '{key}' not found in inputs!")
|
||||
print(f"\nKey '{key}' not found in inputs!")
|
||||
|
||||
print("\n✅ All preprocessor outputs have correct shapes!")
|
||||
print("\nAll preprocessor outputs have correct shapes!")
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@require_package("timm")
|
||||
def test_xvla_action_generation(policy, preprocessor):
|
||||
"""Test XVLA LeRobot implementation generates expected actions."""
|
||||
print("\n" + "=" * 80)
|
||||
@@ -283,17 +278,19 @@ def test_xvla_action_generation(policy, preprocessor):
|
||||
tolerances = [1e-5, 1e-4, 1e-3, 1e-2]
|
||||
for tol in tolerances:
|
||||
is_close = torch.allclose(actual_first_5, EXPECTED_ACTIONS_FIRST_5, atol=tol)
|
||||
status = "✔️" if is_close else "❌"
|
||||
print(f"{status} First 5 actions close (atol={tol}): {is_close}")
|
||||
status = "Success" if is_close else "Failure"
|
||||
print(f"{status}: First 5 actions close (atol={tol}): {is_close}")
|
||||
|
||||
# Assert with reasonable tolerance
|
||||
tolerance = 1e-3
|
||||
assert torch.allclose(actual_first_5, EXPECTED_ACTIONS_FIRST_5, atol=tolerance), (
|
||||
f"First 5 actions differ by more than tolerance ({tolerance})"
|
||||
)
|
||||
print(f"\n✅ Success: Actions match expected values within tolerance ({tolerance})!")
|
||||
print(f"\nSuccess: Actions match expected values within tolerance ({tolerance})!")
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@require_package("timm")
|
||||
def test_xvla_inference_reproducibility(policy, preprocessor):
|
||||
"""Test that XVLA inference is reproducible with the same seed."""
|
||||
print("\n" + "=" * 80)
|
||||
@@ -324,16 +321,16 @@ def test_xvla_inference_reproducibility(policy, preprocessor):
|
||||
print("\nComparing two runs:")
|
||||
print("-" * 80)
|
||||
if torch.allclose(actions_1, actions_2, atol=1e-8):
|
||||
print("✔️ Inference is perfectly reproducible!")
|
||||
print("Inference is perfectly reproducible!")
|
||||
else:
|
||||
diff = torch.abs(actions_1 - actions_2)
|
||||
print("⚠️ Small differences detected:")
|
||||
print("Small differences detected:")
|
||||
print(f" Max diff: {diff.max().item():.6e}")
|
||||
print(f" Mean diff: {diff.mean().item():.6e}")
|
||||
|
||||
assert torch.allclose(actions_1, actions_2, atol=1e-6), "Inference should be reproducible!"
|
||||
|
||||
print("\n✅ Inference is reproducible!")
|
||||
print("\nInference is reproducible!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -353,13 +350,13 @@ if __name__ == "__main__":
|
||||
test_xvla_inference_reproducibility(policy, preprocessor)
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("✅ All tests passed!")
|
||||
print("All tests passed!")
|
||||
print("=" * 80)
|
||||
|
||||
cleanup_memory()
|
||||
except Exception as e:
|
||||
print("\n" + "=" * 80)
|
||||
print(f"❌ Test failed with error: {e}")
|
||||
print(f"Test failed with error: {e}")
|
||||
print("=" * 80)
|
||||
cleanup_memory()
|
||||
raise
|
||||
|
||||
Reference in New Issue
Block a user