diff --git a/tests/policies/xvla/test_xvla_original_vs_lerobot.py b/tests/policies/xvla/test_xvla_original_vs_lerobot.py index 579c0a1cc..b8727ba6c 100644 --- a/tests/policies/xvla/test_xvla_original_vs_lerobot.py +++ b/tests/policies/xvla/test_xvla_original_vs_lerobot.py @@ -39,8 +39,6 @@ pytestmark = pytest.mark.skipif( reason="This test requires XVLA model access and is not meant for CI", ) -from transformers import AutoModel, AutoProcessor # noqa: E402 - from lerobot.processor import PolicyAction, PolicyProcessorPipeline # noqa: E402 from lerobot.utils.constants import OBS_IMAGES, OBS_STATE # noqa: E402 @@ -52,9 +50,14 @@ IMAGE_WIDTH = 224 NUM_VIEWS = 2 # Number of camera views DEVICE = "cuda" if torch.cuda.is_available() else "cpu" MODEL_PATH_LEROBOT = "lerobot/xvla-widowx" -MODEL_PATH_ORIGINAL = "2toINF/X-VLA-WidowX" LIBERO_DOMAIN_ID = 0 # Domain ID for examples purposes +# Expected values from original XVLA implementation (reference values) +EXPECTED_ACTIONS_SHAPE = (30, 20) +EXPECTED_ACTIONS_MEAN = 0.117606 +EXPECTED_ACTIONS_STD = 0.245411 +EXPECTED_ACTIONS_FIRST_5 = torch.tensor([0.2742, 0.4977, 0.0500, 0.7040, -0.2653]) + def cleanup_memory(): """Clean up GPU/MPS memory to prevent OOM errors between tests.""" @@ -118,23 +121,6 @@ def instantiate_lerobot_xvla( return policy, preprocessor, postprocessor -def instantiate_original_xvla( - from_pretrained: bool = False, - model_path: str = MODEL_PATH_ORIGINAL, -): - """Instantiate original XVLA policy from the original implementation.""" - if from_pretrained: - processor = AutoProcessor.from_pretrained(model_path, num_views=NUM_VIEWS, trust_remote_code=True) - model = AutoModel.from_pretrained(model_path, trust_remote_code=True) - else: - raise NotImplementedError("Non-pretrained XVLA instantiation not implemented yet") - - model.to(DEVICE) - model.eval() - - return model, processor - - def create_dummy_data(device=DEVICE): """Create dummy data for testing both implementations.""" batch_size = 1 @@ -161,215 +147,158 @@ def create_dummy_data(device=DEVICE): return batch -def prepare_original_inputs(batch, processor, device=DEVICE): - """Prepare inputs for the original XVLA model.""" - # Convert images from [0, 1] to [0, 255] uint8 for processor - image1 = (batch[f"{OBS_IMAGES}.image"]).byte() - image2 = (batch[f"{OBS_IMAGES}.image2"]).byte() - - # Get task instruction (use first one if batch) - task_instruction = batch["task"][0] if isinstance(batch["task"], list) else batch["task"] - - # Process images and text through original processor - # The processor expects a list of images per sample - processed_inputs = processor( - [image1[0], image2[0]], # Process first sample only for now - task_instruction, - ) - - # Move to correct device and dtype - dtype = torch.float32 - inputs = { - k: v.to(device=device, dtype=dtype) if v.is_floating_point() else v.to(device=device) - for k, v in processed_inputs.items() - } - - # Add proprio and domain_id - inputs.update( - { - "proprio": batch[OBS_STATE][:1].to(device), # First sample only - "domain_id": torch.tensor([LIBERO_DOMAIN_ID], dtype=torch.long, device=device), - } - ) - - return inputs - - -def test_xvla_preprocessor_alignment(): - """Test that LeRobot and Original XVLA preprocessors produce similar outputs.""" +def test_xvla_preprocessor_alignment(policy, preprocessor): + """Test that LeRobot XVLA preprocessor produces expected outputs.""" print("\n" + "=" * 80) - print("Test: XVLA Preprocessor Alignment") + print("Test: XVLA Preprocessor Outputs") print("=" * 80) set_seed_all(42) - print("\n[LeRobot] Instantiating policy and preprocessor...") - lerobot_policy, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_xvla( - from_pretrained=True - ) - - print("\n[Original] Instantiating model and processor...") - original_model, original_processor = instantiate_original_xvla(from_pretrained=True) - print("\nCreating dummy data...") batch = create_dummy_data() print("\n[LeRobot] Preprocessing...") - lerobot_observation = lerobot_preprocessor(deepcopy(batch)) - lerobot_inputs = lerobot_policy._build_model_inputs(lerobot_observation) + lerobot_observation = preprocessor(deepcopy(batch)) + lerobot_inputs = policy._build_model_inputs(lerobot_observation) - print("\n[Original] Preprocessing...") - original_inputs = prepare_original_inputs(batch, original_processor) - - print("\nComparing preprocessor outputs:") + print("\nVerifying preprocessor outputs:") print("-" * 80) - # Compare common keys - common_keys = set(lerobot_inputs.keys()) & set(original_inputs.keys()) - print(f"Common keys: {common_keys}") + # Expected shapes from tester.txt + expected_shapes = { + "domain_id": (1,), + "input_ids": (1, 50), + "proprio": (1, 20), + "image_mask": (1, 2), + "image_input": (1, 2, 3, 224, 224), + } - for key in common_keys: - lerobot_tensor = lerobot_inputs[key] - original_tensor = original_inputs[key] + for key, expected_shape in expected_shapes.items(): + if key in lerobot_inputs: + actual_shape = tuple(lerobot_inputs[key].shape) + print(f"\nšŸ”Ž Key: {key}") + print(f" Expected shape: {expected_shape}") + print(f" Actual shape: {actual_shape}") - print(f"\nšŸ”Ž Key: {key}") - print(f" LeRobot shape: {lerobot_tensor.shape}") - print(f" Original shape: {original_tensor.shape}") - - # Handle batch size difference (we only process first sample for original) - if lerobot_tensor.shape[0] > original_tensor.shape[0]: - lerobot_tensor = lerobot_tensor[:1] - - if lerobot_tensor.shape == original_tensor.shape: - if torch.allclose(lerobot_tensor, original_tensor, atol=1e-5, rtol=1e-5): - print(" āœ”ļø Tensors are equal (allclose with atol=1e-5)") + if actual_shape == expected_shape: + print(" āœ”ļø Shape matches!") else: - diff = torch.abs(lerobot_tensor - original_tensor) - print(" āš ļø Tensors differ") - print(f" Max diff: {diff.max().item():.6e}") - print(f" Mean diff: {diff.mean().item():.6e}") - print(f" Std diff: {diff.std().item():.6e}") + print(" āŒ Shape mismatch!") + + assert actual_shape == expected_shape, f"Shape mismatch for {key}" else: - print(" āš ļø Shapes don't match after alignment") + print(f"\nāš ļø Key '{key}' not found in inputs!") - cleanup_memory() + print("\nāœ… All preprocessor outputs have correct shapes!") -def test_xvla_original_vs_lerobot_pretrained(): - """Test XVLA original implementation vs LeRobot implementation with pretrained weights.""" +def test_xvla_action_generation(policy, preprocessor): + """Test XVLA LeRobot implementation generates expected actions.""" print("\n" + "=" * 80) - print("Test: XVLA Original vs LeRobot with Pretrained Weights (Inference)") + print("Test: XVLA Action Generation Against Expected Values") print("=" * 80) set_seed_all(42) - print("\n[LeRobot] Instantiating policy...") - lerobot_policy, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_xvla( - from_pretrained=True - ) - - print("\n[Original] Instantiating model...") - original_model, original_processor = instantiate_original_xvla(from_pretrained=True) - print("\nCreating dummy data...") batch = create_dummy_data() print("\n[LeRobot] Running inference...") - lerobot_observation = lerobot_preprocessor(deepcopy(batch)) - lerobot_inputs = lerobot_policy._build_model_inputs(lerobot_observation) + lerobot_observation = preprocessor(deepcopy(batch)) + lerobot_inputs = policy._build_model_inputs(lerobot_observation) # Reset seed for inference torch.manual_seed(42) with torch.no_grad(): - lerobot_actions = lerobot_policy.model.generate_actions(**lerobot_inputs, steps=10) + lerobot_actions = policy.model.generate_actions(**lerobot_inputs, steps=10) lerobot_actions = lerobot_actions.squeeze(0).float().cpu() print(f"LeRobot actions shape: {lerobot_actions.shape}") print(f"LeRobot actions mean: {lerobot_actions.mean().item():.6f}") print(f"LeRobot actions std: {lerobot_actions.std().item():.6f}") + print(f"LeRobot actions first 5: {lerobot_actions[0, :5]}") - print("\n[Original] Running inference...") - original_inputs = prepare_original_inputs(batch, original_processor) - - # Reset seed for inference - torch.manual_seed(42) - with torch.no_grad(): - original_actions = original_model.generate_actions(**original_inputs, steps=10) - original_actions = original_actions.squeeze(0).float().cpu() - - print(f"Original actions shape: {original_actions.shape}") - print(f"Original actions mean: {original_actions.mean().item():.6f}") - print(f"Original actions std: {original_actions.std().item():.6f}") + print("\nExpected values (from original XVLA):") + print(f"Expected actions shape: {EXPECTED_ACTIONS_SHAPE}") + print(f"Expected actions mean: {EXPECTED_ACTIONS_MEAN:.6f}") + print(f"Expected actions std: {EXPECTED_ACTIONS_STD:.6f}") + print(f"Expected actions first 5: {EXPECTED_ACTIONS_FIRST_5}") print("\nAction Comparison:") print("-" * 80) - # Compare actions - if lerobot_actions.shape == original_actions.shape: - diff = torch.abs(lerobot_actions - original_actions) - max_diff = diff.max().item() - mean_diff = diff.mean().item() + # Compare shapes + actual_shape = tuple(lerobot_actions.shape) + assert actual_shape == EXPECTED_ACTIONS_SHAPE, ( + f"Shape mismatch: {actual_shape} vs {EXPECTED_ACTIONS_SHAPE}" + ) + print(f"āœ”ļø Shape matches: {actual_shape}") - print(f"Max absolute difference: {max_diff:.6e}") - print(f"Mean absolute difference: {mean_diff:.6e}") - print( - f"Relative difference: {(mean_diff / (torch.abs(original_actions).mean().item() + 1e-8) * 100):.2f}%" - ) + # Compare statistics + actual_mean = lerobot_actions.mean().item() + actual_std = lerobot_actions.std().item() - # Check with different tolerances - tolerances = [1e-5, 1e-4, 1e-3, 1e-2] - for tol in tolerances: - is_close = torch.allclose(lerobot_actions, original_actions, atol=tol) - status = "āœ”ļø" if is_close else "āŒ" - print(f"{status} Actions close (atol={tol}): {is_close}") + mean_diff = abs(actual_mean - EXPECTED_ACTIONS_MEAN) + std_diff = abs(actual_std - EXPECTED_ACTIONS_STD) - # Assert with reasonable tolerance - tolerance = 1e-3 - assert torch.allclose(lerobot_actions, original_actions, atol=tolerance), ( - f"Actions differ by more than tolerance ({tolerance}): max diff = {max_diff:.6e}" - ) - print(f"\nāœ… Success: Actions match within tolerance ({tolerance})!") - else: - print(f"āš ļø Shape mismatch: LeRobot {lerobot_actions.shape} vs Original {original_actions.shape}") + print(f"\nMean: {actual_mean:.6f} (expected: {EXPECTED_ACTIONS_MEAN:.6f}, diff: {mean_diff:.6e})") + print(f"Std: {actual_std:.6f} (expected: {EXPECTED_ACTIONS_STD:.6f}, diff: {std_diff:.6e})") - cleanup_memory() + # Compare first 5 actions + actual_first_5 = lerobot_actions[0, :5] + first_5_diff = torch.abs(actual_first_5 - EXPECTED_ACTIONS_FIRST_5) + + print("\nFirst 5 actions comparison:") + print(f" Actual: {actual_first_5}") + print(f" Expected: {EXPECTED_ACTIONS_FIRST_5}") + print(f" Max diff: {first_5_diff.max().item():.6e}") + print(f" Mean diff: {first_5_diff.mean().item():.6e}") + + # Check with different tolerances + tolerances = [1e-5, 1e-4, 1e-3, 1e-2] + for tol in tolerances: + is_close = torch.allclose(actual_first_5, EXPECTED_ACTIONS_FIRST_5, atol=tol) + status = "āœ”ļø" if is_close else "āŒ" + print(f"{status} First 5 actions close (atol={tol}): {is_close}") + + # Assert with reasonable tolerance + tolerance = 1e-3 + assert torch.allclose(actual_first_5, EXPECTED_ACTIONS_FIRST_5, atol=tolerance), ( + f"First 5 actions differ by more than tolerance ({tolerance})" + ) + print(f"\nāœ… Success: Actions match expected values within tolerance ({tolerance})!") -def test_xvla_inference_reproducibility(): +def test_xvla_inference_reproducibility(policy, preprocessor): """Test that XVLA inference is reproducible with the same seed.""" print("\n" + "=" * 80) print("Test: XVLA Inference Reproducibility") print("=" * 80) - print("\n[LeRobot] Instantiating policy...") - lerobot_policy, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_xvla( - from_pretrained=True - ) - print("\nCreating dummy data...") batch = create_dummy_data() # First inference print("\n[Run 1] Running inference...") set_seed_all(42) - lerobot_observation = lerobot_preprocessor(deepcopy(batch)) - lerobot_inputs = lerobot_policy._build_model_inputs(lerobot_observation) + lerobot_observation = preprocessor(deepcopy(batch)) + lerobot_inputs = policy._build_model_inputs(lerobot_observation) with torch.no_grad(): - actions_1 = lerobot_policy.model.generate_actions(**lerobot_inputs, steps=10) + actions_1 = policy.model.generate_actions(**lerobot_inputs, steps=10) actions_1 = actions_1.squeeze(0).float().cpu() # Second inference with same seed print("\n[Run 2] Running inference with same seed...") set_seed_all(42) - lerobot_observation = lerobot_preprocessor(deepcopy(batch)) - lerobot_inputs = lerobot_policy._build_model_inputs(lerobot_observation) + lerobot_observation = preprocessor(deepcopy(batch)) + lerobot_inputs = policy._build_model_inputs(lerobot_observation) with torch.no_grad(): - actions_2 = lerobot_policy.model.generate_actions(**lerobot_inputs, steps=10) + actions_2 = policy.model.generate_actions(**lerobot_inputs, steps=10) actions_2 = actions_2.squeeze(0).float().cpu() print("\nComparing two runs:") print("-" * 80) - if torch.allclose(actions_1, actions_2, atol=1e-8): print("āœ”ļø Inference is perfectly reproducible!") else: @@ -380,24 +309,33 @@ def test_xvla_inference_reproducibility(): assert torch.allclose(actions_1, actions_2, atol=1e-6), "Inference should be reproducible!" - cleanup_memory() + print("\nāœ… Inference is reproducible!") if __name__ == "__main__": print("\n" + "=" * 80) - print("XVLA Original vs LeRobot Comparison Test Suite") + print("XVLA LeRobot Validation Test Suite") print("=" * 80) try: - test_xvla_preprocessor_alignment() - test_xvla_original_vs_lerobot_pretrained() - test_xvla_inference_reproducibility() + # Initialize model once for all tests + print("\n[Setup] Instantiating LeRobot XVLA policy...") + policy, preprocessor, postprocessor = instantiate_lerobot_xvla(from_pretrained=True) + print("āœ”ļø Model loaded successfully") + + # Run all tests with the same model instance + test_xvla_preprocessor_alignment(policy, preprocessor) + test_xvla_action_generation(policy, preprocessor) + test_xvla_inference_reproducibility(policy, preprocessor) print("\n" + "=" * 80) print("āœ… All tests passed!") print("=" * 80) + + cleanup_memory() except Exception as e: print("\n" + "=" * 80) print(f"āŒ Test failed with error: {e}") print("=" * 80) + cleanup_memory() raise