Files
lerobot/tests/policies/xvla/test_xvla_original_vs_lerobot.py
T
Jade Choghari 2044e52e36 update testing
2025-11-25 14:38:44 +01:00

342 lines
12 KiB
Python

#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test script to verify XVLA policy integration with LeRobot vs the original implementation, only meant to be run locally!"""
import gc
import os
import random
from copy import deepcopy
from typing import Any
import numpy as np
import pytest
import torch
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.policies.xvla.modeling_xvla import XVLAPolicy
from lerobot.policies.xvla.processor_xvla import make_xvla_pre_post_processors
# Skip the whole module when the optional `transformers` package is missing.
# NOTE(review): this guard only runs after the `lerobot.policies.xvla` imports
# above have succeeded; if those modules import `transformers` at import time,
# an ImportError would surface before this skip is reached - TODO confirm.
pytest.importorskip("transformers")
# Mark every test in this module as skipped when the CI or GITHUB_ACTIONS
# environment variable equals "true" (see the `reason` string below).
# The marker skips the tests only; the module is still imported at collection.
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="This test requires XVLA model access and is not meant for CI",
)
from lerobot.processor import PolicyAction, PolicyProcessorPipeline # noqa: E402
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE # noqa: E402
# Constants
# NOTE(review): despite its name, this value is passed as both `n_action_steps`
# and `chunk_size` when a fresh XVLAConfig is built in this file.
DUMMY_ACTION_DIM = 7 # Standard robot arm action dimension
DUMMY_STATE_DIM = 20 # Proprioceptive state dimension (length of the dummy OBS_STATE vector)
# Height and width of the random dummy camera frames.
IMAGE_HEIGHT = 224
IMAGE_WIDTH = 224
NUM_VIEWS = 2 # Number of camera views (passed as `num_image_views` to XVLAConfig)
# Prefer CUDA when it is available, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Default model identifier used when instantiating the LeRobot XVLA policy.
MODEL_PATH_LEROBOT = "lerobot/xvla-widowx"
# NOTE(review): not referenced anywhere else in the visible part of this file.
LIBERO_DOMAIN_ID = 0 # Domain ID for examples purposes
# Expected values from original XVLA implementation (reference values)
# The shape refers to the generated actions after the leading dimension has
# been squeezed out.
EXPECTED_ACTIONS_SHAPE = (30, 20)
# Reference mean / standard deviation over all generated action values.
EXPECTED_ACTIONS_MEAN = 0.117606
EXPECTED_ACTIONS_STD = 0.245411
# Reference for the first five entries of the first generated action row.
EXPECTED_ACTIONS_FIRST_5 = torch.tensor([0.2742, 0.4977, 0.0500, 0.7040, -0.2653])
def cleanup_memory():
    """Free cached accelerator memory so consecutive tests do not run out of memory.

    Runs the Python garbage collector first, then empties the CUDA cache and
    synchronizes when CUDA is available, and finally empties the MPS cache
    when the MPS backend is available. Progress is reported on stdout.
    """
    print("\nCleaning up memory...")

    # Drop unreachable Python objects before asking the backends to release
    # their cached blocks.
    gc.collect()

    if torch.cuda.is_available():
        for cuda_step in (torch.cuda.empty_cache, torch.cuda.synchronize):
            cuda_step()

    mps_ready = torch.backends.mps.is_available()
    if mps_ready:
        torch.mps.empty_cache()

    print("Memory cleanup complete.")
def set_seed_all(seed: int):
    """Seed every random number generator used by these tests.

    Seeds Python's `random`, NumPy's global generator and PyTorch (plus the
    CUDA generators when CUDA is available), then switches cuDNN and PyTorch
    to deterministic behaviour.

    Args:
        seed: Value applied to all generators.
    """
    # Host-side generators, seeded in the same order on every call.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Set deterministic behavior
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False
    # warn_only=True: ops without a deterministic implementation emit a
    # warning instead of raising.
    torch.use_deterministic_algorithms(True, warn_only=True)
def instantiate_lerobot_xvla(
    from_pretrained: bool = False,
    model_path: str = MODEL_PATH_LEROBOT,
) -> tuple[
    Any,  # Policy
    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
    PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
    """Instantiate the LeRobot XVLA policy with its pre- and postprocessor.

    Args:
        from_pretrained: When True, load the policy through
            `XVLAPolicy.from_pretrained(..., strict=False)`. When False, build
            a policy from a fresh `XVLAConfig` that points at `model_path`.
        model_path: Model identifier or path used by either branch.

    Returns:
        A `(policy, preprocessor, postprocessor)` tuple. The policy is moved
        to `DEVICE` and put in eval mode.
    """
    if from_pretrained:
        policy = XVLAPolicy.from_pretrained(
            pretrained_name_or_path=model_path,
            strict=False,
        )
    else:
        # TODO: add resize_imgs_with_padding=IMAGE_SIZE, IMAGE_SIZE?
        config = XVLAConfig(
            base_model_path=model_path,
            n_action_steps=DUMMY_ACTION_DIM,
            chunk_size=DUMMY_ACTION_DIM,
            device=DEVICE,
            num_image_views=NUM_VIEWS,
        )
        policy = XVLAPolicy(config)

    policy.to(DEVICE)
    policy.config.device = DEVICE
    # This helper only serves inference checks; a freshly constructed module
    # starts in training mode, which would leave dropout-style layers active
    # and make the comparisons non-deterministic.
    policy.eval()

    preprocessor, postprocessor = make_xvla_pre_post_processors(
        config=policy.config,
        # Pass None for dataset_stats to disable normalization
        # (original XVLA doesn't normalize).
        dataset_stats=None,
    )
    return policy, preprocessor, postprocessor


# NOTE(review): the test functions in this module take `policy` and
# `preprocessor` arguments, but no fixtures with those names are defined in
# this file, so running it through pytest would fail with "fixture not found".
# The fixtures below mirror the setup of the `__main__` runner. If a conftest
# already provides fixtures with these names, these module-level definitions
# take precedence for this module.
@pytest.fixture(scope="module", name="xvla_components")
def _xvla_components_fixture():
    """Load the pretrained policy once per module and clean up afterwards."""
    components = instantiate_lerobot_xvla(from_pretrained=True)
    yield components
    cleanup_memory()


@pytest.fixture(scope="module", name="policy")
def _policy_fixture(xvla_components):
    """Expose the shared XVLA policy to the tests."""
    return xvla_components[0]


@pytest.fixture(scope="module", name="preprocessor")
def _preprocessor_fixture(xvla_components):
    """Expose the shared XVLA preprocessor to the tests."""
    return xvla_components[1]
def create_dummy_data(device=DEVICE):
    """Build a one-sample random observation batch shared by the tests.

    The batch holds two camera views (`<OBS_IMAGES>.image` and
    `<OBS_IMAGES>.image2`), a random proprioceptive state under `OBS_STATE`
    and a text prompt under `"task"`.

    The images are returned as uint8 CHW tensors; no conversion to float is
    performed here. Pixel values are drawn with `np.random.randint(0, 255)`,
    whose upper bound is exclusive.

    Args:
        device: Device on which the image and state tensors are placed.

    Returns:
        A dict with the keys described above.
    """
    batch_size = 1
    prompt = "Pick up the red block and place it in the bin"

    def random_view_batch():
        # One random HWC uint8 frame per sample, permuted to CHW and stacked.
        frames = []
        for _ in range(batch_size):
            hwc = np.random.randint(0, 255, (IMAGE_HEIGHT, IMAGE_WIDTH, 3), dtype=np.uint8)
            frames.append(torch.from_numpy(hwc).permute(2, 0, 1))
        return torch.stack(frames).to(device)

    # Generation order (first view, second view, state) determines how the
    # seeded generators are consumed, so it must stay fixed.
    first_view = random_view_batch()
    second_view = random_view_batch()
    state = torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device)

    return {
        f"{OBS_IMAGES}.image": first_view,
        f"{OBS_IMAGES}.image2": second_view,
        OBS_STATE: state,
        "task": [prompt] * batch_size,
    }
def test_xvla_preprocessor_alignment(policy, preprocessor):
    """Check that the LeRobot XVLA preprocessor yields model inputs of the expected shapes.

    Args:
        policy: XVLA policy; its `_build_model_inputs` turns the preprocessed
            observation into the model input dict.
        preprocessor: Callable applied to a copy of the dummy batch.

    Raises:
        AssertionError: If an expected key has the wrong shape, or if any
            expected key is absent from the model inputs.
    """
    print("\n" + "=" * 80)
    print("Test: XVLA Preprocessor Outputs")
    print("=" * 80)

    set_seed_all(42)
    print("\nCreating dummy data...")
    batch = create_dummy_data()

    print("\n[LeRobot] Preprocessing...")
    # deepcopy keeps the original batch untouched in case the preprocessor mutates it.
    lerobot_observation = preprocessor(deepcopy(batch))
    lerobot_inputs = policy._build_model_inputs(lerobot_observation)

    print("\nVerifying preprocessor outputs:")
    print("-" * 80)
    # Expected shapes from tester.txt
    expected_shapes = {
        "domain_id": (1,),
        "input_ids": (1, 50),
        "proprio": (1, 20),
        "image_mask": (1, 2),
        "image_input": (1, 2, 3, 224, 224),
    }

    missing_keys = []
    for key, expected_shape in expected_shapes.items():
        if key not in lerobot_inputs:
            print(f"\n⚠️ Key '{key}' not found in inputs!")
            # Keep going so every missing key is reported in one run.
            missing_keys.append(key)
            continue

        actual_shape = tuple(lerobot_inputs[key].shape)
        print(f"\n🔎 Key: {key}")
        print(f" Expected shape: {expected_shape}")
        print(f" Actual shape: {actual_shape}")
        if actual_shape == expected_shape:
            print(" ✔️ Shape matches!")
        else:
            print(" ❌ Shape mismatch!")
        assert actual_shape == expected_shape, (
            f"Shape mismatch for {key}: expected {expected_shape}, got {actual_shape}"
        )

    # A missing key used to produce only a warning, after which the success
    # banner was still printed; treat it as a failure instead.
    assert not missing_keys, f"Expected model inputs are missing: {missing_keys}"
    print("\n✅ All preprocessor outputs have correct shapes!")
def test_xvla_action_generation(policy, preprocessor):
    """Check that the LeRobot XVLA policy reproduces the reference actions.

    Generates actions for the seeded dummy batch and compares their shape,
    mean, standard deviation and the first five values of the first row with
    the module-level `EXPECTED_ACTIONS_*` reference values.

    Args:
        policy: XVLA policy providing `_build_model_inputs` and
            `model.generate_actions`.
        preprocessor: Callable applied to a copy of the dummy batch.

    Raises:
        AssertionError: If the shape differs, or if the mean, the standard
            deviation or the first five values deviate by more than the
            tolerance.
    """
    print("\n" + "=" * 80)
    print("Test: XVLA Action Generation Against Expected Values")
    print("=" * 80)

    set_seed_all(42)
    print("\nCreating dummy data...")
    batch = create_dummy_data()

    print("\n[LeRobot] Running inference...")
    lerobot_observation = preprocessor(deepcopy(batch))
    lerobot_inputs = policy._build_model_inputs(lerobot_observation)
    # Reset seed for inference
    torch.manual_seed(42)
    with torch.no_grad():
        lerobot_actions = policy.model.generate_actions(**lerobot_inputs, steps=10)
    lerobot_actions = lerobot_actions.squeeze(0).float().cpu()

    # Compute the summary statistics once and reuse them below.
    actual_mean = lerobot_actions.mean().item()
    actual_std = lerobot_actions.std().item()

    print(f"LeRobot actions shape: {lerobot_actions.shape}")
    print(f"LeRobot actions mean: {actual_mean:.6f}")
    print(f"LeRobot actions std: {actual_std:.6f}")
    print(f"LeRobot actions first 5: {lerobot_actions[0, :5]}")

    print("\nExpected values (from original XVLA):")
    print(f"Expected actions shape: {EXPECTED_ACTIONS_SHAPE}")
    print(f"Expected actions mean: {EXPECTED_ACTIONS_MEAN:.6f}")
    print(f"Expected actions std: {EXPECTED_ACTIONS_STD:.6f}")
    print(f"Expected actions first 5: {EXPECTED_ACTIONS_FIRST_5}")

    print("\nAction Comparison:")
    print("-" * 80)

    # Compare shapes
    actual_shape = tuple(lerobot_actions.shape)
    assert actual_shape == EXPECTED_ACTIONS_SHAPE, (
        f"Shape mismatch: {actual_shape} vs {EXPECTED_ACTIONS_SHAPE}"
    )
    print(f"✔️ Shape matches: {actual_shape}")

    # Compare statistics
    mean_diff = abs(actual_mean - EXPECTED_ACTIONS_MEAN)
    std_diff = abs(actual_std - EXPECTED_ACTIONS_STD)
    print(f"\nMean: {actual_mean:.6f} (expected: {EXPECTED_ACTIONS_MEAN:.6f}, diff: {mean_diff:.6e})")
    print(f"Std: {actual_std:.6f} (expected: {EXPECTED_ACTIONS_STD:.6f}, diff: {std_diff:.6e})")

    # Compare first 5 actions
    actual_first_5 = lerobot_actions[0, :5]
    first_5_diff = torch.abs(actual_first_5 - EXPECTED_ACTIONS_FIRST_5)
    print("\nFirst 5 actions comparison:")
    print(f" Actual: {actual_first_5}")
    print(f" Expected: {EXPECTED_ACTIONS_FIRST_5}")
    print(f" Max diff: {first_5_diff.max().item():.6e}")
    print(f" Mean diff: {first_5_diff.mean().item():.6e}")

    # Informational report: show how tight the match is at several tolerances.
    for tol in (1e-5, 1e-4, 1e-3, 1e-2):
        is_close = torch.allclose(actual_first_5, EXPECTED_ACTIONS_FIRST_5, atol=tol)
        status = "✔️" if is_close else "❌"
        print(f"{status} First 5 actions close (atol={tol}): {is_close}")

    # Assert with reasonable tolerance
    tolerance = 1e-3
    # The reference statistics were previously only printed; enforce them so a
    # drift outside the first five values is caught as well.
    assert mean_diff <= tolerance, (
        f"Mean differs by {mean_diff:.6e}, more than tolerance ({tolerance})"
    )
    assert std_diff <= tolerance, (
        f"Std differs by {std_diff:.6e}, more than tolerance ({tolerance})"
    )
    assert torch.allclose(actual_first_5, EXPECTED_ACTIONS_FIRST_5, atol=tolerance), (
        f"First 5 actions differ by more than tolerance ({tolerance})"
    )
    print(f"\n✅ Success: Actions match expected values within tolerance ({tolerance})!")
def test_xvla_inference_reproducibility(policy, preprocessor):
    """Verify that two seeded inference passes on the same batch yield the same actions.

    Args:
        policy: XVLA policy providing `_build_model_inputs` and
            `model.generate_actions`.
        preprocessor: Callable applied to a copy of the dummy batch.

    Raises:
        AssertionError: If the two runs are not close under
            `torch.allclose(..., atol=1e-6)`.
    """
    print("\n" + "=" * 80)
    print("Test: XVLA Inference Reproducibility")
    print("=" * 80)

    print("\nCreating dummy data...")
    batch = create_dummy_data()

    def seeded_rollout(announcement):
        # Announce the run, reseed everything, then run the full
        # preprocessing + generation pipeline on a fresh copy of the batch.
        print(announcement)
        set_seed_all(42)
        observation = preprocessor(deepcopy(batch))
        model_inputs = policy._build_model_inputs(observation)
        with torch.no_grad():
            generated = policy.model.generate_actions(**model_inputs, steps=10)
        return generated.squeeze(0).float().cpu()

    actions_1 = seeded_rollout("\n[Run 1] Running inference...")
    actions_2 = seeded_rollout("\n[Run 2] Running inference with same seed...")

    print("\nComparing two runs:")
    print("-" * 80)
    if not torch.allclose(actions_1, actions_2, atol=1e-8):
        diff = torch.abs(actions_1 - actions_2)
        print("⚠️ Small differences detected:")
        print(f" Max diff: {diff.max().item():.6e}")
        print(f" Mean diff: {diff.mean().item():.6e}")
    else:
        print("✔️ Inference is perfectly reproducible!")

    assert torch.allclose(actions_1, actions_2, atol=1e-6), "Inference should be reproducible!"
    print("\n✅ Inference is reproducible!")
# Script entry point: load the pretrained policy once and run every check
# against that single instance.
if __name__ == "__main__":
    print("\n" + "=" * 80)
    print("XVLA LeRobot Validation Test Suite")
    print("=" * 80)
    try:
        # Initialize model once for all tests
        print("\n[Setup] Instantiating LeRobot XVLA policy...")
        policy, preprocessor, _postprocessor = instantiate_lerobot_xvla(from_pretrained=True)
        print("✔️ Model loaded successfully")

        # Run all tests with the same model instance
        test_xvla_preprocessor_alignment(policy, preprocessor)
        test_xvla_action_generation(policy, preprocessor)
        test_xvla_inference_reproducibility(policy, preprocessor)
    except Exception as e:
        print("\n" + "=" * 80)
        print(f"❌ Test failed with error: {e}")
        print("=" * 80)
        # Re-raise so the script exits with a traceback and a non-zero status.
        raise
    else:
        print("\n" + "=" * 80)
        print("✅ All tests passed!")
        print("=" * 80)
    finally:
        # Single cleanup path: runs after success and after any failure,
        # including interruptions that are not `Exception` subclasses.
        cleanup_memory()