mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 00:59:46 +00:00
update testing
This commit is contained in:
@@ -39,8 +39,6 @@ pytestmark = pytest.mark.skipif(
|
||||
reason="This test requires XVLA model access and is not meant for CI",
|
||||
)
|
||||
|
||||
from transformers import AutoModel, AutoProcessor # noqa: E402
|
||||
|
||||
from lerobot.processor import PolicyAction, PolicyProcessorPipeline # noqa: E402
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE # noqa: E402
|
||||
|
||||
@@ -52,9 +50,14 @@ IMAGE_WIDTH = 224
|
||||
NUM_VIEWS = 2 # Number of camera views
|
||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
MODEL_PATH_LEROBOT = "lerobot/xvla-widowx"
|
||||
MODEL_PATH_ORIGINAL = "2toINF/X-VLA-WidowX"
|
||||
LIBERO_DOMAIN_ID = 0 # Domain ID for examples purposes
|
||||
|
||||
# Expected values from original XVLA implementation (reference values)
|
||||
EXPECTED_ACTIONS_SHAPE = (30, 20)
|
||||
EXPECTED_ACTIONS_MEAN = 0.117606
|
||||
EXPECTED_ACTIONS_STD = 0.245411
|
||||
EXPECTED_ACTIONS_FIRST_5 = torch.tensor([0.2742, 0.4977, 0.0500, 0.7040, -0.2653])
|
||||
|
||||
|
||||
def cleanup_memory():
|
||||
"""Clean up GPU/MPS memory to prevent OOM errors between tests."""
|
||||
@@ -118,23 +121,6 @@ def instantiate_lerobot_xvla(
|
||||
return policy, preprocessor, postprocessor
|
||||
|
||||
|
||||
def instantiate_original_xvla(
|
||||
from_pretrained: bool = False,
|
||||
model_path: str = MODEL_PATH_ORIGINAL,
|
||||
):
|
||||
"""Instantiate original XVLA policy from the original implementation."""
|
||||
if from_pretrained:
|
||||
processor = AutoProcessor.from_pretrained(model_path, num_views=NUM_VIEWS, trust_remote_code=True)
|
||||
model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
|
||||
else:
|
||||
raise NotImplementedError("Non-pretrained XVLA instantiation not implemented yet")
|
||||
|
||||
model.to(DEVICE)
|
||||
model.eval()
|
||||
|
||||
return model, processor
|
||||
|
||||
|
||||
def create_dummy_data(device=DEVICE):
|
||||
"""Create dummy data for testing both implementations."""
|
||||
batch_size = 1
|
||||
@@ -161,215 +147,158 @@ def create_dummy_data(device=DEVICE):
|
||||
return batch
|
||||
|
||||
|
||||
def prepare_original_inputs(batch, processor, device=DEVICE):
|
||||
"""Prepare inputs for the original XVLA model."""
|
||||
# Convert images from [0, 1] to [0, 255] uint8 for processor
|
||||
image1 = (batch[f"{OBS_IMAGES}.image"]).byte()
|
||||
image2 = (batch[f"{OBS_IMAGES}.image2"]).byte()
|
||||
|
||||
# Get task instruction (use first one if batch)
|
||||
task_instruction = batch["task"][0] if isinstance(batch["task"], list) else batch["task"]
|
||||
|
||||
# Process images and text through original processor
|
||||
# The processor expects a list of images per sample
|
||||
processed_inputs = processor(
|
||||
[image1[0], image2[0]], # Process first sample only for now
|
||||
task_instruction,
|
||||
)
|
||||
|
||||
# Move to correct device and dtype
|
||||
dtype = torch.float32
|
||||
inputs = {
|
||||
k: v.to(device=device, dtype=dtype) if v.is_floating_point() else v.to(device=device)
|
||||
for k, v in processed_inputs.items()
|
||||
}
|
||||
|
||||
# Add proprio and domain_id
|
||||
inputs.update(
|
||||
{
|
||||
"proprio": batch[OBS_STATE][:1].to(device), # First sample only
|
||||
"domain_id": torch.tensor([LIBERO_DOMAIN_ID], dtype=torch.long, device=device),
|
||||
}
|
||||
)
|
||||
|
||||
return inputs
|
||||
|
||||
|
||||
def test_xvla_preprocessor_alignment():
|
||||
"""Test that LeRobot and Original XVLA preprocessors produce similar outputs."""
|
||||
def test_xvla_preprocessor_alignment(policy, preprocessor):
|
||||
"""Test that LeRobot XVLA preprocessor produces expected outputs."""
|
||||
print("\n" + "=" * 80)
|
||||
print("Test: XVLA Preprocessor Alignment")
|
||||
print("Test: XVLA Preprocessor Outputs")
|
||||
print("=" * 80)
|
||||
|
||||
set_seed_all(42)
|
||||
|
||||
print("\n[LeRobot] Instantiating policy and preprocessor...")
|
||||
lerobot_policy, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_xvla(
|
||||
from_pretrained=True
|
||||
)
|
||||
|
||||
print("\n[Original] Instantiating model and processor...")
|
||||
original_model, original_processor = instantiate_original_xvla(from_pretrained=True)
|
||||
|
||||
print("\nCreating dummy data...")
|
||||
batch = create_dummy_data()
|
||||
|
||||
print("\n[LeRobot] Preprocessing...")
|
||||
lerobot_observation = lerobot_preprocessor(deepcopy(batch))
|
||||
lerobot_inputs = lerobot_policy._build_model_inputs(lerobot_observation)
|
||||
lerobot_observation = preprocessor(deepcopy(batch))
|
||||
lerobot_inputs = policy._build_model_inputs(lerobot_observation)
|
||||
|
||||
print("\n[Original] Preprocessing...")
|
||||
original_inputs = prepare_original_inputs(batch, original_processor)
|
||||
|
||||
print("\nComparing preprocessor outputs:")
|
||||
print("\nVerifying preprocessor outputs:")
|
||||
print("-" * 80)
|
||||
|
||||
# Compare common keys
|
||||
common_keys = set(lerobot_inputs.keys()) & set(original_inputs.keys())
|
||||
print(f"Common keys: {common_keys}")
|
||||
# Expected shapes from tester.txt
|
||||
expected_shapes = {
|
||||
"domain_id": (1,),
|
||||
"input_ids": (1, 50),
|
||||
"proprio": (1, 20),
|
||||
"image_mask": (1, 2),
|
||||
"image_input": (1, 2, 3, 224, 224),
|
||||
}
|
||||
|
||||
for key in common_keys:
|
||||
lerobot_tensor = lerobot_inputs[key]
|
||||
original_tensor = original_inputs[key]
|
||||
for key, expected_shape in expected_shapes.items():
|
||||
if key in lerobot_inputs:
|
||||
actual_shape = tuple(lerobot_inputs[key].shape)
|
||||
print(f"\n🔎 Key: {key}")
|
||||
print(f" Expected shape: {expected_shape}")
|
||||
print(f" Actual shape: {actual_shape}")
|
||||
|
||||
print(f"\n🔎 Key: {key}")
|
||||
print(f" LeRobot shape: {lerobot_tensor.shape}")
|
||||
print(f" Original shape: {original_tensor.shape}")
|
||||
|
||||
# Handle batch size difference (we only process first sample for original)
|
||||
if lerobot_tensor.shape[0] > original_tensor.shape[0]:
|
||||
lerobot_tensor = lerobot_tensor[:1]
|
||||
|
||||
if lerobot_tensor.shape == original_tensor.shape:
|
||||
if torch.allclose(lerobot_tensor, original_tensor, atol=1e-5, rtol=1e-5):
|
||||
print(" ✔️ Tensors are equal (allclose with atol=1e-5)")
|
||||
if actual_shape == expected_shape:
|
||||
print(" ✔️ Shape matches!")
|
||||
else:
|
||||
diff = torch.abs(lerobot_tensor - original_tensor)
|
||||
print(" ⚠️ Tensors differ")
|
||||
print(f" Max diff: {diff.max().item():.6e}")
|
||||
print(f" Mean diff: {diff.mean().item():.6e}")
|
||||
print(f" Std diff: {diff.std().item():.6e}")
|
||||
print(" ❌ Shape mismatch!")
|
||||
|
||||
assert actual_shape == expected_shape, f"Shape mismatch for {key}"
|
||||
else:
|
||||
print(" ⚠️ Shapes don't match after alignment")
|
||||
print(f"\n⚠️ Key '{key}' not found in inputs!")
|
||||
|
||||
cleanup_memory()
|
||||
print("\n✅ All preprocessor outputs have correct shapes!")
|
||||
|
||||
|
||||
def test_xvla_original_vs_lerobot_pretrained():
|
||||
"""Test XVLA original implementation vs LeRobot implementation with pretrained weights."""
|
||||
def test_xvla_action_generation(policy, preprocessor):
|
||||
"""Test XVLA LeRobot implementation generates expected actions."""
|
||||
print("\n" + "=" * 80)
|
||||
print("Test: XVLA Original vs LeRobot with Pretrained Weights (Inference)")
|
||||
print("Test: XVLA Action Generation Against Expected Values")
|
||||
print("=" * 80)
|
||||
|
||||
set_seed_all(42)
|
||||
|
||||
print("\n[LeRobot] Instantiating policy...")
|
||||
lerobot_policy, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_xvla(
|
||||
from_pretrained=True
|
||||
)
|
||||
|
||||
print("\n[Original] Instantiating model...")
|
||||
original_model, original_processor = instantiate_original_xvla(from_pretrained=True)
|
||||
|
||||
print("\nCreating dummy data...")
|
||||
batch = create_dummy_data()
|
||||
|
||||
print("\n[LeRobot] Running inference...")
|
||||
lerobot_observation = lerobot_preprocessor(deepcopy(batch))
|
||||
lerobot_inputs = lerobot_policy._build_model_inputs(lerobot_observation)
|
||||
lerobot_observation = preprocessor(deepcopy(batch))
|
||||
lerobot_inputs = policy._build_model_inputs(lerobot_observation)
|
||||
|
||||
# Reset seed for inference
|
||||
torch.manual_seed(42)
|
||||
with torch.no_grad():
|
||||
lerobot_actions = lerobot_policy.model.generate_actions(**lerobot_inputs, steps=10)
|
||||
lerobot_actions = policy.model.generate_actions(**lerobot_inputs, steps=10)
|
||||
lerobot_actions = lerobot_actions.squeeze(0).float().cpu()
|
||||
|
||||
print(f"LeRobot actions shape: {lerobot_actions.shape}")
|
||||
print(f"LeRobot actions mean: {lerobot_actions.mean().item():.6f}")
|
||||
print(f"LeRobot actions std: {lerobot_actions.std().item():.6f}")
|
||||
print(f"LeRobot actions first 5: {lerobot_actions[0, :5]}")
|
||||
|
||||
print("\n[Original] Running inference...")
|
||||
original_inputs = prepare_original_inputs(batch, original_processor)
|
||||
|
||||
# Reset seed for inference
|
||||
torch.manual_seed(42)
|
||||
with torch.no_grad():
|
||||
original_actions = original_model.generate_actions(**original_inputs, steps=10)
|
||||
original_actions = original_actions.squeeze(0).float().cpu()
|
||||
|
||||
print(f"Original actions shape: {original_actions.shape}")
|
||||
print(f"Original actions mean: {original_actions.mean().item():.6f}")
|
||||
print(f"Original actions std: {original_actions.std().item():.6f}")
|
||||
print("\nExpected values (from original XVLA):")
|
||||
print(f"Expected actions shape: {EXPECTED_ACTIONS_SHAPE}")
|
||||
print(f"Expected actions mean: {EXPECTED_ACTIONS_MEAN:.6f}")
|
||||
print(f"Expected actions std: {EXPECTED_ACTIONS_STD:.6f}")
|
||||
print(f"Expected actions first 5: {EXPECTED_ACTIONS_FIRST_5}")
|
||||
|
||||
print("\nAction Comparison:")
|
||||
print("-" * 80)
|
||||
|
||||
# Compare actions
|
||||
if lerobot_actions.shape == original_actions.shape:
|
||||
diff = torch.abs(lerobot_actions - original_actions)
|
||||
max_diff = diff.max().item()
|
||||
mean_diff = diff.mean().item()
|
||||
# Compare shapes
|
||||
actual_shape = tuple(lerobot_actions.shape)
|
||||
assert actual_shape == EXPECTED_ACTIONS_SHAPE, (
|
||||
f"Shape mismatch: {actual_shape} vs {EXPECTED_ACTIONS_SHAPE}"
|
||||
)
|
||||
print(f"✔️ Shape matches: {actual_shape}")
|
||||
|
||||
print(f"Max absolute difference: {max_diff:.6e}")
|
||||
print(f"Mean absolute difference: {mean_diff:.6e}")
|
||||
print(
|
||||
f"Relative difference: {(mean_diff / (torch.abs(original_actions).mean().item() + 1e-8) * 100):.2f}%"
|
||||
)
|
||||
# Compare statistics
|
||||
actual_mean = lerobot_actions.mean().item()
|
||||
actual_std = lerobot_actions.std().item()
|
||||
|
||||
# Check with different tolerances
|
||||
tolerances = [1e-5, 1e-4, 1e-3, 1e-2]
|
||||
for tol in tolerances:
|
||||
is_close = torch.allclose(lerobot_actions, original_actions, atol=tol)
|
||||
status = "✔️" if is_close else "❌"
|
||||
print(f"{status} Actions close (atol={tol}): {is_close}")
|
||||
mean_diff = abs(actual_mean - EXPECTED_ACTIONS_MEAN)
|
||||
std_diff = abs(actual_std - EXPECTED_ACTIONS_STD)
|
||||
|
||||
# Assert with reasonable tolerance
|
||||
tolerance = 1e-3
|
||||
assert torch.allclose(lerobot_actions, original_actions, atol=tolerance), (
|
||||
f"Actions differ by more than tolerance ({tolerance}): max diff = {max_diff:.6e}"
|
||||
)
|
||||
print(f"\n✅ Success: Actions match within tolerance ({tolerance})!")
|
||||
else:
|
||||
print(f"⚠️ Shape mismatch: LeRobot {lerobot_actions.shape} vs Original {original_actions.shape}")
|
||||
print(f"\nMean: {actual_mean:.6f} (expected: {EXPECTED_ACTIONS_MEAN:.6f}, diff: {mean_diff:.6e})")
|
||||
print(f"Std: {actual_std:.6f} (expected: {EXPECTED_ACTIONS_STD:.6f}, diff: {std_diff:.6e})")
|
||||
|
||||
cleanup_memory()
|
||||
# Compare first 5 actions
|
||||
actual_first_5 = lerobot_actions[0, :5]
|
||||
first_5_diff = torch.abs(actual_first_5 - EXPECTED_ACTIONS_FIRST_5)
|
||||
|
||||
print("\nFirst 5 actions comparison:")
|
||||
print(f" Actual: {actual_first_5}")
|
||||
print(f" Expected: {EXPECTED_ACTIONS_FIRST_5}")
|
||||
print(f" Max diff: {first_5_diff.max().item():.6e}")
|
||||
print(f" Mean diff: {first_5_diff.mean().item():.6e}")
|
||||
|
||||
# Check with different tolerances
|
||||
tolerances = [1e-5, 1e-4, 1e-3, 1e-2]
|
||||
for tol in tolerances:
|
||||
is_close = torch.allclose(actual_first_5, EXPECTED_ACTIONS_FIRST_5, atol=tol)
|
||||
status = "✔️" if is_close else "❌"
|
||||
print(f"{status} First 5 actions close (atol={tol}): {is_close}")
|
||||
|
||||
# Assert with reasonable tolerance
|
||||
tolerance = 1e-3
|
||||
assert torch.allclose(actual_first_5, EXPECTED_ACTIONS_FIRST_5, atol=tolerance), (
|
||||
f"First 5 actions differ by more than tolerance ({tolerance})"
|
||||
)
|
||||
print(f"\n✅ Success: Actions match expected values within tolerance ({tolerance})!")
|
||||
|
||||
|
||||
def test_xvla_inference_reproducibility():
|
||||
def test_xvla_inference_reproducibility(policy, preprocessor):
|
||||
"""Test that XVLA inference is reproducible with the same seed."""
|
||||
print("\n" + "=" * 80)
|
||||
print("Test: XVLA Inference Reproducibility")
|
||||
print("=" * 80)
|
||||
|
||||
print("\n[LeRobot] Instantiating policy...")
|
||||
lerobot_policy, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_xvla(
|
||||
from_pretrained=True
|
||||
)
|
||||
|
||||
print("\nCreating dummy data...")
|
||||
batch = create_dummy_data()
|
||||
|
||||
# First inference
|
||||
print("\n[Run 1] Running inference...")
|
||||
set_seed_all(42)
|
||||
lerobot_observation = lerobot_preprocessor(deepcopy(batch))
|
||||
lerobot_inputs = lerobot_policy._build_model_inputs(lerobot_observation)
|
||||
lerobot_observation = preprocessor(deepcopy(batch))
|
||||
lerobot_inputs = policy._build_model_inputs(lerobot_observation)
|
||||
with torch.no_grad():
|
||||
actions_1 = lerobot_policy.model.generate_actions(**lerobot_inputs, steps=10)
|
||||
actions_1 = policy.model.generate_actions(**lerobot_inputs, steps=10)
|
||||
actions_1 = actions_1.squeeze(0).float().cpu()
|
||||
|
||||
# Second inference with same seed
|
||||
print("\n[Run 2] Running inference with same seed...")
|
||||
set_seed_all(42)
|
||||
lerobot_observation = lerobot_preprocessor(deepcopy(batch))
|
||||
lerobot_inputs = lerobot_policy._build_model_inputs(lerobot_observation)
|
||||
lerobot_observation = preprocessor(deepcopy(batch))
|
||||
lerobot_inputs = policy._build_model_inputs(lerobot_observation)
|
||||
with torch.no_grad():
|
||||
actions_2 = lerobot_policy.model.generate_actions(**lerobot_inputs, steps=10)
|
||||
actions_2 = policy.model.generate_actions(**lerobot_inputs, steps=10)
|
||||
actions_2 = actions_2.squeeze(0).float().cpu()
|
||||
|
||||
print("\nComparing two runs:")
|
||||
print("-" * 80)
|
||||
|
||||
if torch.allclose(actions_1, actions_2, atol=1e-8):
|
||||
print("✔️ Inference is perfectly reproducible!")
|
||||
else:
|
||||
@@ -380,24 +309,33 @@ def test_xvla_inference_reproducibility():
|
||||
|
||||
assert torch.allclose(actions_1, actions_2, atol=1e-6), "Inference should be reproducible!"
|
||||
|
||||
cleanup_memory()
|
||||
print("\n✅ Inference is reproducible!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("\n" + "=" * 80)
|
||||
print("XVLA Original vs LeRobot Comparison Test Suite")
|
||||
print("XVLA LeRobot Validation Test Suite")
|
||||
print("=" * 80)
|
||||
|
||||
try:
|
||||
test_xvla_preprocessor_alignment()
|
||||
test_xvla_original_vs_lerobot_pretrained()
|
||||
test_xvla_inference_reproducibility()
|
||||
# Initialize model once for all tests
|
||||
print("\n[Setup] Instantiating LeRobot XVLA policy...")
|
||||
policy, preprocessor, postprocessor = instantiate_lerobot_xvla(from_pretrained=True)
|
||||
print("✔️ Model loaded successfully")
|
||||
|
||||
# Run all tests with the same model instance
|
||||
test_xvla_preprocessor_alignment(policy, preprocessor)
|
||||
test_xvla_action_generation(policy, preprocessor)
|
||||
test_xvla_inference_reproducibility(policy, preprocessor)
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("✅ All tests passed!")
|
||||
print("=" * 80)
|
||||
|
||||
cleanup_memory()
|
||||
except Exception as e:
|
||||
print("\n" + "=" * 80)
|
||||
print(f"❌ Test failed with error: {e}")
|
||||
print("=" * 80)
|
||||
cleanup_memory()
|
||||
raise
|
||||
|
||||
Reference in New Issue
Block a user