mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 02:59:50 +00:00
upgrade test, fix failing
This commit is contained in:
@@ -18,7 +18,6 @@
|
|||||||
# ruff: noqa: E402
|
# ruff: noqa: E402
|
||||||
|
|
||||||
import gc
|
import gc
|
||||||
import os
|
|
||||||
import random
|
import random
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -27,20 +26,12 @@ import numpy as np
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
# Skip if transformers is not available
|
|
||||||
pytest.importorskip("transformers")
|
|
||||||
|
|
||||||
# Skip this entire module in CI
|
|
||||||
pytestmark = pytest.mark.skipif(
|
|
||||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
|
||||||
reason="This test requires XVLA model access and is not meant for CI",
|
|
||||||
)
|
|
||||||
|
|
||||||
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
||||||
from lerobot.policies.xvla.modeling_xvla import XVLAPolicy
|
from lerobot.policies.xvla.modeling_xvla import XVLAPolicy
|
||||||
from lerobot.policies.xvla.processor_xvla import make_xvla_pre_post_processors
|
from lerobot.policies.xvla.processor_xvla import make_xvla_pre_post_processors
|
||||||
from lerobot.processor import PolicyAction, PolicyProcessorPipeline # noqa: E402
|
from lerobot.processor import PolicyAction, PolicyProcessorPipeline # noqa: E402
|
||||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE # noqa: E402
|
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE # noqa: E402
|
||||||
|
from tests.utils import require_package # noqa: E402
|
||||||
|
|
||||||
# Constants
|
# Constants
|
||||||
DUMMY_ACTION_DIM = 7 # Standard robot arm action dimension
|
DUMMY_ACTION_DIM = 7 # Standard robot arm action dimension
|
||||||
@@ -171,6 +162,8 @@ def preprocessor(xvla_components):
|
|||||||
return xvla_components[1]
|
return xvla_components[1]
|
||||||
|
|
||||||
|
|
||||||
|
@require_package("transformers")
|
||||||
|
@require_package("timm")
|
||||||
def test_xvla_preprocessor_alignment(policy, preprocessor):
|
def test_xvla_preprocessor_alignment(policy, preprocessor):
|
||||||
"""Test that LeRobot XVLA preprocessor produces expected outputs."""
|
"""Test that LeRobot XVLA preprocessor produces expected outputs."""
|
||||||
print("\n" + "=" * 80)
|
print("\n" + "=" * 80)
|
||||||
@@ -201,22 +194,24 @@ def test_xvla_preprocessor_alignment(policy, preprocessor):
|
|||||||
for key, expected_shape in expected_shapes.items():
|
for key, expected_shape in expected_shapes.items():
|
||||||
if key in lerobot_inputs:
|
if key in lerobot_inputs:
|
||||||
actual_shape = tuple(lerobot_inputs[key].shape)
|
actual_shape = tuple(lerobot_inputs[key].shape)
|
||||||
print(f"\n🔎 Key: {key}")
|
print(f"\nKey: {key}")
|
||||||
print(f" Expected shape: {expected_shape}")
|
print(f"Expected shape: {expected_shape}")
|
||||||
print(f" Actual shape: {actual_shape}")
|
print(f"Actual shape: {actual_shape}")
|
||||||
|
|
||||||
if actual_shape == expected_shape:
|
if actual_shape == expected_shape:
|
||||||
print(" ✔️ Shape matches!")
|
print("Shape matches!")
|
||||||
else:
|
else:
|
||||||
print(" ❌ Shape mismatch!")
|
print("Shape mismatch!")
|
||||||
|
|
||||||
assert actual_shape == expected_shape, f"Shape mismatch for {key}"
|
assert actual_shape == expected_shape, f"Shape mismatch for {key}"
|
||||||
else:
|
else:
|
||||||
print(f"\n⚠️ Key '{key}' not found in inputs!")
|
print(f"\nKey '{key}' not found in inputs!")
|
||||||
|
|
||||||
print("\n✅ All preprocessor outputs have correct shapes!")
|
print("\nAll preprocessor outputs have correct shapes!")
|
||||||
|
|
||||||
|
|
||||||
|
@require_package("transformers")
|
||||||
|
@require_package("timm")
|
||||||
def test_xvla_action_generation(policy, preprocessor):
|
def test_xvla_action_generation(policy, preprocessor):
|
||||||
"""Test XVLA LeRobot implementation generates expected actions."""
|
"""Test XVLA LeRobot implementation generates expected actions."""
|
||||||
print("\n" + "=" * 80)
|
print("\n" + "=" * 80)
|
||||||
@@ -283,17 +278,19 @@ def test_xvla_action_generation(policy, preprocessor):
|
|||||||
tolerances = [1e-5, 1e-4, 1e-3, 1e-2]
|
tolerances = [1e-5, 1e-4, 1e-3, 1e-2]
|
||||||
for tol in tolerances:
|
for tol in tolerances:
|
||||||
is_close = torch.allclose(actual_first_5, EXPECTED_ACTIONS_FIRST_5, atol=tol)
|
is_close = torch.allclose(actual_first_5, EXPECTED_ACTIONS_FIRST_5, atol=tol)
|
||||||
status = "✔️" if is_close else "❌"
|
status = "Success" if is_close else "Failure"
|
||||||
print(f"{status} First 5 actions close (atol={tol}): {is_close}")
|
print(f"{status}: First 5 actions close (atol={tol}): {is_close}")
|
||||||
|
|
||||||
# Assert with reasonable tolerance
|
# Assert with reasonable tolerance
|
||||||
tolerance = 1e-3
|
tolerance = 1e-3
|
||||||
assert torch.allclose(actual_first_5, EXPECTED_ACTIONS_FIRST_5, atol=tolerance), (
|
assert torch.allclose(actual_first_5, EXPECTED_ACTIONS_FIRST_5, atol=tolerance), (
|
||||||
f"First 5 actions differ by more than tolerance ({tolerance})"
|
f"First 5 actions differ by more than tolerance ({tolerance})"
|
||||||
)
|
)
|
||||||
print(f"\n✅ Success: Actions match expected values within tolerance ({tolerance})!")
|
print(f"\nSuccess: Actions match expected values within tolerance ({tolerance})!")
|
||||||
|
|
||||||
|
|
||||||
|
@require_package("transformers")
|
||||||
|
@require_package("timm")
|
||||||
def test_xvla_inference_reproducibility(policy, preprocessor):
|
def test_xvla_inference_reproducibility(policy, preprocessor):
|
||||||
"""Test that XVLA inference is reproducible with the same seed."""
|
"""Test that XVLA inference is reproducible with the same seed."""
|
||||||
print("\n" + "=" * 80)
|
print("\n" + "=" * 80)
|
||||||
@@ -324,16 +321,16 @@ def test_xvla_inference_reproducibility(policy, preprocessor):
|
|||||||
print("\nComparing two runs:")
|
print("\nComparing two runs:")
|
||||||
print("-" * 80)
|
print("-" * 80)
|
||||||
if torch.allclose(actions_1, actions_2, atol=1e-8):
|
if torch.allclose(actions_1, actions_2, atol=1e-8):
|
||||||
print("✔️ Inference is perfectly reproducible!")
|
print("Inference is perfectly reproducible!")
|
||||||
else:
|
else:
|
||||||
diff = torch.abs(actions_1 - actions_2)
|
diff = torch.abs(actions_1 - actions_2)
|
||||||
print("⚠️ Small differences detected:")
|
print("Small differences detected:")
|
||||||
print(f" Max diff: {diff.max().item():.6e}")
|
print(f" Max diff: {diff.max().item():.6e}")
|
||||||
print(f" Mean diff: {diff.mean().item():.6e}")
|
print(f" Mean diff: {diff.mean().item():.6e}")
|
||||||
|
|
||||||
assert torch.allclose(actions_1, actions_2, atol=1e-6), "Inference should be reproducible!"
|
assert torch.allclose(actions_1, actions_2, atol=1e-6), "Inference should be reproducible!"
|
||||||
|
|
||||||
print("\n✅ Inference is reproducible!")
|
print("\nInference is reproducible!")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@@ -353,13 +350,13 @@ if __name__ == "__main__":
|
|||||||
test_xvla_inference_reproducibility(policy, preprocessor)
|
test_xvla_inference_reproducibility(policy, preprocessor)
|
||||||
|
|
||||||
print("\n" + "=" * 80)
|
print("\n" + "=" * 80)
|
||||||
print("✅ All tests passed!")
|
print("All tests passed!")
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
|
|
||||||
cleanup_memory()
|
cleanup_memory()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("\n" + "=" * 80)
|
print("\n" + "=" * 80)
|
||||||
print(f"❌ Test failed with error: {e}")
|
print(f"Test failed with error: {e}")
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
cleanup_memory()
|
cleanup_memory()
|
||||||
raise
|
raise
|
||||||
|
|||||||
Reference in New Issue
Block a user