mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 22:20:06 +00:00
cleanup tests
This commit is contained in:
@@ -1,22 +1,27 @@
|
||||
"""Test script to verify PI0OpenPI policy integration with LeRobot vs the original implementation."""
|
||||
"""Test script to verify PI0OpenPI policy integration with LeRobot vs the original implementation, only meant to be run locally!"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
# Skip entire module if openpi or transformers is not available
|
||||
# Skip if openpi or transformers is not available
|
||||
pytest.importorskip("openpi")
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from openpi.models_pytorch import preprocessing_pytorch as openpi_preprocessing
|
||||
# Skip this entire module in CI
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="This test requires local OpenPI installation and is not meant for CI",
|
||||
)
|
||||
|
||||
from openpi.models_pytorch import preprocessing_pytorch as openpi_preprocessing # noqa: E402
|
||||
|
||||
# NOTE: Assumes PYTHONPATH is set to include OpenPI src as per instructions.
|
||||
from openpi.models_pytorch.pi0_pytorch import PI0Pytorch
|
||||
from transformers import AutoTokenizer
|
||||
from openpi.models_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402
|
||||
from transformers import AutoTokenizer # noqa: E402
|
||||
|
||||
from lerobot.policies.pi0_openpi import PI0OpenPIConfig, PI0OpenPIPolicy
|
||||
from tests.utils import require_cuda
|
||||
from lerobot.policies.pi0_openpi import PI0OpenPIConfig, PI0OpenPIPolicy # noqa: E402
|
||||
|
||||
DUMMY_ACTION_DIM = 32
|
||||
DUMMY_STATE_DIM = 32
|
||||
@@ -320,7 +325,6 @@ def create_original_observation_from_lerobot(lerobot_pi0, batch):
|
||||
)
|
||||
|
||||
|
||||
@require_cuda
|
||||
def test_pi0_original_vs_lerobot():
|
||||
"""Test PI0 original implementation vs LeRobot implementation."""
|
||||
print("Initializing models...")
|
||||
@@ -333,7 +337,7 @@ def test_pi0_original_vs_lerobot():
|
||||
batch = create_dummy_data()
|
||||
|
||||
# Test 1: Each model with its own preprocessing (more realistic end-to-end test)
|
||||
print("\n=== TEST 1: Each model with its own preprocessing ===")
|
||||
print("\nTEST 1: Each model with its own preprocessing")
|
||||
print("Creating observation for OpenPI using OpenPI's own preprocessing...")
|
||||
pi0_obs_openpi = create_original_observation_with_openpi_preprocessing(batch)
|
||||
|
||||
@@ -372,7 +376,7 @@ def test_pi0_original_vs_lerobot():
|
||||
print(f"Max absolute difference: {torch.abs(lerobot_actions_own - openpi_actions).max().item():.6f}")
|
||||
|
||||
# Test 2: Both models with LeRobot preprocessing (isolates model differences)
|
||||
print("\n=== TEST 2: Both models with LeRobot preprocessing (model comparison) ===")
|
||||
print("\nTEST 2: Both models with LeRobot preprocessing (model comparison)")
|
||||
print("Creating observation for OpenPI using LeRobot's preprocessing...")
|
||||
pi0_obs_lerobot = create_original_observation_from_lerobot(lerobot_pi0, batch)
|
||||
|
||||
@@ -397,8 +401,3 @@ def test_pi0_original_vs_lerobot():
|
||||
|
||||
# Add assertions for pytest
|
||||
assert is_close_1e2, f"Models should produce similar results (atol=1e-2), max diff: {max_diff}"
|
||||
|
||||
print("\n=== SUMMARY ===")
|
||||
print("Test 1 compares end-to-end pipelines (each model with its own preprocessing)")
|
||||
print("Test 2 isolates model differences (both models with LeRobot preprocessing)")
|
||||
print("Both tests completed successfully!")
|
||||
|
||||
Reference in New Issue
Block a user