mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 02:59:50 +00:00
add testing
This commit is contained in:
@@ -0,0 +1,409 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Test script to verify XVLA policy integration with LeRobot vs the original implementation, only meant to be run locally!"""
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
from copy import deepcopy
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# Skip if transformers is not available
|
||||||
|
pytest.importorskip("transformers")
|
||||||
|
|
||||||
|
# Skip this entire module in CI
|
||||||
|
pytestmark = pytest.mark.skipif(
|
||||||
|
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||||
|
reason="This test requires XVLA model access and is not meant for CI",
|
||||||
|
)
|
||||||
|
|
||||||
|
from transformers import AutoModel, AutoProcessor # noqa: E402
|
||||||
|
|
||||||
|
from lerobot.configs.policies import PreTrainedConfig # noqa: E402
|
||||||
|
from lerobot.envs.factory import make_env_config # noqa: E402
|
||||||
|
from lerobot.policies.factory import make_policy, make_pre_post_processors # noqa: E402
|
||||||
|
from lerobot.processor import PolicyAction, PolicyProcessorPipeline # noqa: E402
|
||||||
|
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE # noqa: E402
|
||||||
|
|
||||||
|
# Constants
|
||||||
|
DUMMY_ACTION_DIM = 7 # Standard robot arm action dimension
|
||||||
|
DUMMY_STATE_DIM = 20 # Proprioceptive state dimension
|
||||||
|
IMAGE_HEIGHT = 224
|
||||||
|
IMAGE_WIDTH = 224
|
||||||
|
NUM_VIEWS = 2 # Number of camera views
|
||||||
|
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
MODEL_PATH_LEROBOT = "lerobot/xvla-base"
|
||||||
|
MODEL_PATH_ORIGINAL = "2toINF/X-VLA-Pt"
|
||||||
|
LIBERO_DOMAIN_ID = 0 # Domain ID for examples purposes
|
||||||
|
|
||||||
|
|
||||||
|
def cleanup_memory():
|
||||||
|
"""Clean up GPU/MPS memory to prevent OOM errors between tests."""
|
||||||
|
print("\nCleaning up memory...")
|
||||||
|
gc.collect()
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
if torch.backends.mps.is_available():
|
||||||
|
torch.mps.empty_cache()
|
||||||
|
print("Memory cleanup complete.")
|
||||||
|
|
||||||
|
|
||||||
|
def set_seed_all(seed: int):
|
||||||
|
"""Set random seed for all RNG sources to ensure reproducibility."""
|
||||||
|
random.seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
|
||||||
|
# Set deterministic behavior
|
||||||
|
torch.backends.cudnn.deterministic = True
|
||||||
|
torch.backends.cudnn.benchmark = False
|
||||||
|
torch.use_deterministic_algorithms(True, warn_only=True)
|
||||||
|
|
||||||
|
|
||||||
|
def instantiate_lerobot_xvla(
|
||||||
|
from_pretrained: bool = False,
|
||||||
|
model_path: str = MODEL_PATH_LEROBOT,
|
||||||
|
) -> tuple[
|
||||||
|
Any, # Policy
|
||||||
|
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||||
|
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||||
|
]:
|
||||||
|
"""Instantiate LeRobot XVLA policy with preprocessor and postprocessor."""
|
||||||
|
if from_pretrained:
|
||||||
|
cfg = PreTrainedConfig.from_pretrained(model_path)
|
||||||
|
cfg.pretrained_path = model_path
|
||||||
|
else:
|
||||||
|
# For non-pretrained, we'd need to create a config from scratch
|
||||||
|
raise NotImplementedError("Non-pretrained XVLA instantiation not implemented yet")
|
||||||
|
|
||||||
|
cfg.device = DEVICE
|
||||||
|
env_cfg = make_env_config("libero", task="libero_spatial")
|
||||||
|
|
||||||
|
policy = make_policy(
|
||||||
|
cfg=cfg,
|
||||||
|
env_cfg=env_cfg,
|
||||||
|
)
|
||||||
|
|
||||||
|
policy.to(DEVICE)
|
||||||
|
policy.eval()
|
||||||
|
|
||||||
|
preprocessor_overrides = {
|
||||||
|
"device_processor": {"device": str(cfg.device)},
|
||||||
|
}
|
||||||
|
|
||||||
|
preprocessor, postprocessor = make_pre_post_processors(
|
||||||
|
policy_cfg=cfg,
|
||||||
|
pretrained_path=cfg.pretrained_path,
|
||||||
|
preprocessor_overrides=preprocessor_overrides,
|
||||||
|
)
|
||||||
|
|
||||||
|
return policy, preprocessor, postprocessor
|
||||||
|
|
||||||
|
|
||||||
|
def instantiate_original_xvla(
|
||||||
|
from_pretrained: bool = False,
|
||||||
|
model_path: str = MODEL_PATH_ORIGINAL,
|
||||||
|
):
|
||||||
|
"""Instantiate original XVLA policy from the original implementation."""
|
||||||
|
if from_pretrained:
|
||||||
|
processor = AutoProcessor.from_pretrained(
|
||||||
|
model_path,
|
||||||
|
num_views=NUM_VIEWS,
|
||||||
|
trust_remote_code=True
|
||||||
|
)
|
||||||
|
model = AutoModel.from_pretrained(
|
||||||
|
model_path,
|
||||||
|
trust_remote_code=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Non-pretrained XVLA instantiation not implemented yet")
|
||||||
|
|
||||||
|
model.to(DEVICE)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
return model, processor
|
||||||
|
|
||||||
|
|
||||||
|
def create_dummy_data(device=DEVICE):
|
||||||
|
"""Create dummy data for testing both implementations."""
|
||||||
|
batch_size = 2
|
||||||
|
prompt = "Pick up the red block and place it in the bin"
|
||||||
|
|
||||||
|
# Create random RGB images in [0, 255] uint8 range (as PIL images would be)
|
||||||
|
# Then convert to [0, 1] float32 range for LeRobot
|
||||||
|
def fake_rgb(H, W):
|
||||||
|
arr = np.random.randint(0, 255, (H, W, 3), dtype=np.uint8)
|
||||||
|
t = torch.from_numpy(arr).permute(2, 0, 1) # CHW
|
||||||
|
t = t.float() / 255.0 # Normalize to [0, 1]
|
||||||
|
return t
|
||||||
|
|
||||||
|
batch = {
|
||||||
|
f"{OBS_IMAGES}.image": torch.stack([fake_rgb(IMAGE_HEIGHT, IMAGE_WIDTH) for _ in range(batch_size)]).to(device),
|
||||||
|
f"{OBS_IMAGES}.image2": torch.stack([fake_rgb(IMAGE_HEIGHT, IMAGE_WIDTH) for _ in range(batch_size)]).to(device),
|
||||||
|
OBS_STATE: torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device),
|
||||||
|
"task": [prompt for _ in range(batch_size)],
|
||||||
|
}
|
||||||
|
|
||||||
|
return batch
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_original_inputs(batch, processor, device=DEVICE):
|
||||||
|
"""Prepare inputs for the original XVLA model."""
|
||||||
|
batch_size = batch[OBS_STATE].shape[0]
|
||||||
|
|
||||||
|
# Convert images from [0, 1] to [0, 255] uint8 for processor
|
||||||
|
image1 = (batch[f"{OBS_IMAGES}.image"] * 255).byte()
|
||||||
|
image2 = (batch[f"{OBS_IMAGES}.image2"] * 255).byte()
|
||||||
|
|
||||||
|
# Get task instruction (use first one if batch)
|
||||||
|
task_instruction = batch["task"][0] if isinstance(batch["task"], list) else batch["task"]
|
||||||
|
|
||||||
|
# Process images and text through original processor
|
||||||
|
# The processor expects a list of images per sample
|
||||||
|
processed_inputs = processor(
|
||||||
|
[image1[0], image2[0]], # Process first sample only for now
|
||||||
|
task_instruction
|
||||||
|
)
|
||||||
|
|
||||||
|
# Move to correct device and dtype
|
||||||
|
dtype = torch.float32
|
||||||
|
inputs = {k: v.to(device=device, dtype=dtype) if v.is_floating_point() else v.to(device=device)
|
||||||
|
for k, v in processed_inputs.items()}
|
||||||
|
|
||||||
|
# Add proprio and domain_id
|
||||||
|
inputs.update({
|
||||||
|
"proprio": batch[OBS_STATE][:1].to(device), # First sample only
|
||||||
|
"domain_id": torch.tensor([LIBERO_DOMAIN_ID], dtype=torch.long, device=device),
|
||||||
|
})
|
||||||
|
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
|
||||||
|
def test_xvla_preprocessor_alignment():
|
||||||
|
"""Test that LeRobot and Original XVLA preprocessors produce similar outputs."""
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("Test: XVLA Preprocessor Alignment")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
set_seed_all(42)
|
||||||
|
|
||||||
|
print("\n[LeRobot] Instantiating policy and preprocessor...")
|
||||||
|
lerobot_policy, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_xvla(
|
||||||
|
from_pretrained=True
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\n[Original] Instantiating model and processor...")
|
||||||
|
original_model, original_processor = instantiate_original_xvla(from_pretrained=True)
|
||||||
|
|
||||||
|
print("\nCreating dummy data...")
|
||||||
|
batch = create_dummy_data()
|
||||||
|
|
||||||
|
print("\n[LeRobot] Preprocessing...")
|
||||||
|
lerobot_observation = lerobot_preprocessor(deepcopy(batch))
|
||||||
|
lerobot_inputs = lerobot_policy._build_model_inputs(lerobot_observation)
|
||||||
|
|
||||||
|
print("\n[Original] Preprocessing...")
|
||||||
|
original_inputs = prepare_original_inputs(batch, original_processor)
|
||||||
|
|
||||||
|
print("\nComparing preprocessor outputs:")
|
||||||
|
print("-" * 80)
|
||||||
|
|
||||||
|
# Compare common keys
|
||||||
|
common_keys = set(lerobot_inputs.keys()) & set(original_inputs.keys())
|
||||||
|
print(f"Common keys: {common_keys}")
|
||||||
|
|
||||||
|
for key in common_keys:
|
||||||
|
lerobot_tensor = lerobot_inputs[key]
|
||||||
|
original_tensor = original_inputs[key]
|
||||||
|
|
||||||
|
print(f"\n🔎 Key: {key}")
|
||||||
|
print(f" LeRobot shape: {lerobot_tensor.shape}")
|
||||||
|
print(f" Original shape: {original_tensor.shape}")
|
||||||
|
|
||||||
|
# Handle batch size difference (we only process first sample for original)
|
||||||
|
if lerobot_tensor.shape[0] > original_tensor.shape[0]:
|
||||||
|
lerobot_tensor = lerobot_tensor[:1]
|
||||||
|
|
||||||
|
if lerobot_tensor.shape == original_tensor.shape:
|
||||||
|
if torch.allclose(lerobot_tensor, original_tensor, atol=1e-5, rtol=1e-5):
|
||||||
|
print(" ✔️ Tensors are equal (allclose with atol=1e-5)")
|
||||||
|
else:
|
||||||
|
diff = torch.abs(lerobot_tensor - original_tensor)
|
||||||
|
print(" ⚠️ Tensors differ")
|
||||||
|
print(f" Max diff: {diff.max().item():.6e}")
|
||||||
|
print(f" Mean diff: {diff.mean().item():.6e}")
|
||||||
|
print(f" Std diff: {diff.std().item():.6e}")
|
||||||
|
else:
|
||||||
|
print(" ⚠️ Shapes don't match after alignment")
|
||||||
|
|
||||||
|
cleanup_memory()
|
||||||
|
|
||||||
|
|
||||||
|
def test_xvla_original_vs_lerobot_pretrained():
|
||||||
|
"""Test XVLA original implementation vs LeRobot implementation with pretrained weights."""
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("Test: XVLA Original vs LeRobot with Pretrained Weights (Inference)")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
set_seed_all(42)
|
||||||
|
|
||||||
|
print("\n[LeRobot] Instantiating policy...")
|
||||||
|
lerobot_policy, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_xvla(
|
||||||
|
from_pretrained=True
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\n[Original] Instantiating model...")
|
||||||
|
original_model, original_processor = instantiate_original_xvla(from_pretrained=True)
|
||||||
|
|
||||||
|
print("\nCreating dummy data...")
|
||||||
|
batch = create_dummy_data()
|
||||||
|
|
||||||
|
print("\n[LeRobot] Running inference...")
|
||||||
|
lerobot_observation = lerobot_preprocessor(deepcopy(batch))
|
||||||
|
lerobot_inputs = lerobot_policy._build_model_inputs(lerobot_observation)
|
||||||
|
|
||||||
|
# Reset seed for inference
|
||||||
|
torch.manual_seed(42)
|
||||||
|
with torch.no_grad():
|
||||||
|
lerobot_actions = lerobot_policy.model.generate_actions(**lerobot_inputs, steps=10)
|
||||||
|
lerobot_actions = lerobot_actions.squeeze(0).float().cpu()
|
||||||
|
|
||||||
|
print(f"LeRobot actions shape: {lerobot_actions.shape}")
|
||||||
|
print(f"LeRobot actions mean: {lerobot_actions.mean().item():.6f}")
|
||||||
|
print(f"LeRobot actions std: {lerobot_actions.std().item():.6f}")
|
||||||
|
|
||||||
|
print("\n[Original] Running inference...")
|
||||||
|
original_inputs = prepare_original_inputs(batch, original_processor)
|
||||||
|
|
||||||
|
# Reset seed for inference
|
||||||
|
torch.manual_seed(42)
|
||||||
|
with torch.no_grad():
|
||||||
|
original_actions = original_model.generate_actions(**original_inputs, steps=10)
|
||||||
|
original_actions = original_actions.squeeze(0).float().cpu()
|
||||||
|
|
||||||
|
print(f"Original actions shape: {original_actions.shape}")
|
||||||
|
print(f"Original actions mean: {original_actions.mean().item():.6f}")
|
||||||
|
print(f"Original actions std: {original_actions.std().item():.6f}")
|
||||||
|
|
||||||
|
print("\nAction Comparison:")
|
||||||
|
print("-" * 80)
|
||||||
|
|
||||||
|
# Compare actions
|
||||||
|
if lerobot_actions.shape == original_actions.shape:
|
||||||
|
diff = torch.abs(lerobot_actions - original_actions)
|
||||||
|
max_diff = diff.max().item()
|
||||||
|
mean_diff = diff.mean().item()
|
||||||
|
|
||||||
|
print(f"Max absolute difference: {max_diff:.6e}")
|
||||||
|
print(f"Mean absolute difference: {mean_diff:.6e}")
|
||||||
|
print(f"Relative difference: {(mean_diff / (torch.abs(original_actions).mean().item() + 1e-8) * 100):.2f}%")
|
||||||
|
|
||||||
|
# Check with different tolerances
|
||||||
|
tolerances = [1e-5, 1e-4, 1e-3, 1e-2]
|
||||||
|
for tol in tolerances:
|
||||||
|
is_close = torch.allclose(lerobot_actions, original_actions, atol=tol)
|
||||||
|
status = "✔️" if is_close else "❌"
|
||||||
|
print(f"{status} Actions close (atol={tol}): {is_close}")
|
||||||
|
|
||||||
|
# Assert with reasonable tolerance
|
||||||
|
tolerance = 1e-3
|
||||||
|
assert torch.allclose(lerobot_actions, original_actions, atol=tolerance), (
|
||||||
|
f"Actions differ by more than tolerance ({tolerance}): max diff = {max_diff:.6e}"
|
||||||
|
)
|
||||||
|
print(f"\n✅ Success: Actions match within tolerance ({tolerance})!")
|
||||||
|
else:
|
||||||
|
print(f"⚠️ Shape mismatch: LeRobot {lerobot_actions.shape} vs Original {original_actions.shape}")
|
||||||
|
|
||||||
|
cleanup_memory()
|
||||||
|
|
||||||
|
|
||||||
|
def test_xvla_inference_reproducibility():
|
||||||
|
"""Test that XVLA inference is reproducible with the same seed."""
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("Test: XVLA Inference Reproducibility")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
print("\n[LeRobot] Instantiating policy...")
|
||||||
|
lerobot_policy, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_xvla(
|
||||||
|
from_pretrained=True
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\nCreating dummy data...")
|
||||||
|
batch = create_dummy_data()
|
||||||
|
|
||||||
|
# First inference
|
||||||
|
print("\n[Run 1] Running inference...")
|
||||||
|
set_seed_all(42)
|
||||||
|
lerobot_observation = lerobot_preprocessor(deepcopy(batch))
|
||||||
|
lerobot_inputs = lerobot_policy._build_model_inputs(lerobot_observation)
|
||||||
|
with torch.no_grad():
|
||||||
|
actions_1 = lerobot_policy.model.generate_actions(**lerobot_inputs, steps=10)
|
||||||
|
actions_1 = actions_1.squeeze(0).float().cpu()
|
||||||
|
|
||||||
|
# Second inference with same seed
|
||||||
|
print("\n[Run 2] Running inference with same seed...")
|
||||||
|
set_seed_all(42)
|
||||||
|
lerobot_observation = lerobot_preprocessor(deepcopy(batch))
|
||||||
|
lerobot_inputs = lerobot_policy._build_model_inputs(lerobot_observation)
|
||||||
|
with torch.no_grad():
|
||||||
|
actions_2 = lerobot_policy.model.generate_actions(**lerobot_inputs, steps=10)
|
||||||
|
actions_2 = actions_2.squeeze(0).float().cpu()
|
||||||
|
|
||||||
|
print("\nComparing two runs:")
|
||||||
|
print("-" * 80)
|
||||||
|
|
||||||
|
if torch.allclose(actions_1, actions_2, atol=1e-8):
|
||||||
|
print("✔️ Inference is perfectly reproducible!")
|
||||||
|
else:
|
||||||
|
diff = torch.abs(actions_1 - actions_2)
|
||||||
|
print(f"⚠️ Small differences detected:")
|
||||||
|
print(f" Max diff: {diff.max().item():.6e}")
|
||||||
|
print(f" Mean diff: {diff.mean().item():.6e}")
|
||||||
|
|
||||||
|
assert torch.allclose(actions_1, actions_2, atol=1e-6), "Inference should be reproducible!"
|
||||||
|
|
||||||
|
cleanup_memory()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("XVLA Original vs LeRobot Comparison Test Suite")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
try:
|
||||||
|
test_xvla_preprocessor_alignment()
|
||||||
|
test_xvla_original_vs_lerobot_pretrained()
|
||||||
|
test_xvla_inference_reproducibility()
|
||||||
|
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("✅ All tests passed!")
|
||||||
|
print("=" * 80)
|
||||||
|
except Exception as e:
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print(f"❌ Test failed with error: {e}")
|
||||||
|
print("=" * 80)
|
||||||
|
raise
|
||||||
|
|
||||||
@@ -0,0 +1,190 @@
|
|||||||
|
import random
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from xvla.models.modeling_xvla import XVLA
|
||||||
|
|
||||||
|
# from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
||||||
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
|
from lerobot.envs.factory import make_env_config
|
||||||
|
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||||
|
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
|
||||||
|
|
||||||
|
torch.manual_seed(42)
|
||||||
|
random.seed(42)
|
||||||
|
np.random.seed(42)
|
||||||
|
observation_height: int = 224
|
||||||
|
observation_width: int = 224 # todo: jadechoghari, image size is different for the two models
|
||||||
|
# create an observation dict
|
||||||
|
OBS = {
|
||||||
|
f"{OBS_IMAGES}.image": torch.randn(1, 3, observation_height, observation_width),
|
||||||
|
f"{OBS_IMAGES}.image2": torch.randn(1, 3, observation_height, observation_width),
|
||||||
|
OBS_STATE: torch.randn(1, 20), # ONLY if OBS_STATE is already a string
|
||||||
|
"task": "put the object in the box",
|
||||||
|
}
|
||||||
|
|
||||||
|
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
|
||||||
|
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
|
||||||
|
|
||||||
|
|
||||||
|
def fake_rgb(H, W):
|
||||||
|
arr = np.random.randint(0, 255, (H, W, 3), dtype=np.uint8)
|
||||||
|
t = torch.from_numpy(arr).permute(2, 0, 1) # CHW
|
||||||
|
t = t.unsqueeze(0).float()
|
||||||
|
# normalize pixel to imagenet
|
||||||
|
return t
|
||||||
|
|
||||||
|
|
||||||
|
OBS[f"{OBS_IMAGES}.image"] = fake_rgb(observation_height, observation_width)
|
||||||
|
OBS[f"{OBS_IMAGES}.image2"] = fake_rgb(observation_height, observation_width)
|
||||||
|
|
||||||
|
cfg = PreTrainedConfig.from_pretrained("/raid/jade/models/xvla-libero-og_migrated")
|
||||||
|
cfg.pretrained_path = "/raid/jade/models/xvla-libero-og_migrated"
|
||||||
|
env_cfg = make_env_config("libero", task="libero_spatial")
|
||||||
|
policy = make_policy(
|
||||||
|
cfg=cfg,
|
||||||
|
env_cfg=env_cfg,
|
||||||
|
)
|
||||||
|
|
||||||
|
policy.eval()
|
||||||
|
|
||||||
|
preprocessor_overrides = {
|
||||||
|
"device_processor": {"device": str(cfg.device)},
|
||||||
|
}
|
||||||
|
|
||||||
|
preprocessor, postprocessor = make_pre_post_processors(
|
||||||
|
policy_cfg=cfg,
|
||||||
|
pretrained_path=cfg.pretrained_path,
|
||||||
|
preprocessor_overrides=preprocessor_overrides,
|
||||||
|
)
|
||||||
|
|
||||||
|
observation = preprocessor(OBS)
|
||||||
|
inputs = policy._build_model_inputs(observation)
|
||||||
|
|
||||||
|
|
||||||
|
#### now the og model ###########################################################
|
||||||
|
from xvla.models.processing_xvla import XVLAProcessor
|
||||||
|
|
||||||
|
processor = XVLAProcessor.from_pretrained("/raid/jade/models/xvla-libero", num_views=2)
|
||||||
|
inputs_1 = processor([OBS[f"{OBS_IMAGES}.image"], OBS[f"{OBS_IMAGES}.image2"]], OBS["task"])
|
||||||
|
domain_id = torch.tensor([3], dtype=torch.long)
|
||||||
|
inputs.update(
|
||||||
|
{
|
||||||
|
"proprio": OBS[OBS_STATE].to("cuda"),
|
||||||
|
"domain_id": domain_id.to("cuda"),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# check the preprocessor
|
||||||
|
for k in inputs.keys() & inputs_1.keys(): # intersection of keys
|
||||||
|
a = inputs[k]
|
||||||
|
b = inputs_1[k].to("cuda")
|
||||||
|
|
||||||
|
print(f"\n🔎 Key: {k}")
|
||||||
|
|
||||||
|
# Check shape
|
||||||
|
print(" shape:", a.shape, b.shape)
|
||||||
|
|
||||||
|
# Check if close
|
||||||
|
if torch.allclose(a, b, atol=1e-5, rtol=1e-5):
|
||||||
|
print(" ✔️ tensors are equal (allclose)")
|
||||||
|
else:
|
||||||
|
diff = torch.abs(a - b)
|
||||||
|
print(" ❌ tensors differ")
|
||||||
|
print(" max diff:", diff.max().item())
|
||||||
|
print(" mean diff:", diff.mean().item())
|
||||||
|
|
||||||
|
|
||||||
|
model = XVLA.from_pretrained("/raid/jade/models/xvla-libero")
|
||||||
|
model.eval()
|
||||||
|
model.to("cuda")
|
||||||
|
|
||||||
|
action = model.generate_actions(**inputs, steps=10).squeeze(0).float().cpu().numpy()
|
||||||
|
action_1 = policy.model.generate_actions(**inputs, steps=10).squeeze(0).float().cpu().numpy()
|
||||||
|
|
||||||
|
# np all close
|
||||||
|
print(np.allclose(action, action_1, atol=1e-2, rtol=1e-2))
|
||||||
|
print("max diff:", np.max(np.abs(action - action_1)))
|
||||||
|
print("mean diff:", np.mean(np.abs(action - action_1)))
|
||||||
|
|
||||||
|
|
||||||
|
import random
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from xvla.models.configuration_xvla import XVLAConfig
|
||||||
|
from xvla.models.modeling_xvla import XVLA
|
||||||
|
from xvla.models.processor_xvla import XVLAProcessor
|
||||||
|
|
||||||
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
|
from lerobot.envs.factory import make_env_config
|
||||||
|
from lerobot.policies.factory import make_policy
|
||||||
|
|
||||||
|
cfg = XVLAConfig.from_pretrained("/raid/jade/models/xvla-libero")
|
||||||
|
model = XVLA.from_pretrained("/raid/jade/models/xvla-libero")
|
||||||
|
model.eval()
|
||||||
|
model.to("cuda")
|
||||||
|
processor = XVLAProcessor.from_pretrained("/raid/jade/models/xvla-libero")
|
||||||
|
# /raid/jade/models/xvla-libero
|
||||||
|
# seet seed
|
||||||
|
torch.manual_seed(42)
|
||||||
|
random.seed(42)
|
||||||
|
np.random.seed(42)
|
||||||
|
|
||||||
|
|
||||||
|
def make_random_pil_images(num_images=3, H=480, W=640):
|
||||||
|
images = []
|
||||||
|
for _ in range(num_images):
|
||||||
|
# Random RGB image
|
||||||
|
arr = np.random.randint(0, 256, (H, W, 3), dtype=np.uint8)
|
||||||
|
img = Image.fromarray(arr)
|
||||||
|
images.append(img)
|
||||||
|
return images
|
||||||
|
|
||||||
|
|
||||||
|
# Example:
|
||||||
|
images = make_random_pil_images()
|
||||||
|
language_instruction = "This is a random image"
|
||||||
|
# Multimodal preprocessing by processor
|
||||||
|
inputs = processor(images, language_instruction)
|
||||||
|
if not {"input_ids", "image_input", "image_mask"}.issubset(inputs):
|
||||||
|
raise ValueError("Processor did not return the expected keys.")
|
||||||
|
|
||||||
|
proprio = torch.randn(1, 20)
|
||||||
|
domain_id = torch.tensor([0], dtype=torch.long)
|
||||||
|
|
||||||
|
# Align to model's device/dtype
|
||||||
|
device = model.device
|
||||||
|
dtype = next(model.parameters()).dtype
|
||||||
|
|
||||||
|
|
||||||
|
def to_model(t: torch.Tensor) -> torch.Tensor:
|
||||||
|
if not isinstance(t, torch.Tensor):
|
||||||
|
t = torch.as_tensor(t)
|
||||||
|
# cast floats to model dtype, keep integral/bool as-is
|
||||||
|
return t.to(device=device, dtype=dtype) if t.is_floating_point() else t.to(device=device)
|
||||||
|
|
||||||
|
|
||||||
|
inputs = {k: to_model(v) for k, v in inputs.items()}
|
||||||
|
inputs.update(
|
||||||
|
{
|
||||||
|
"proprio": to_model(proprio),
|
||||||
|
"domain_id": domain_id.to(device),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Inference
|
||||||
|
action = model.generate_actions(**inputs, steps=10).squeeze(0).float().cpu().numpy()
|
||||||
|
|
||||||
|
|
||||||
|
#### now for lerobot model #####################################################
|
||||||
|
|
||||||
|
cfg = PreTrainedConfig.from_pretrained("/raid/jade/models/xvla-libero-og_migrated")
|
||||||
|
env_cfg = make_env_config("libero", task="libero_spatial")
|
||||||
|
cfg.pretrained_path = "/raid/jade/models/xvla-libero-og_migrated"
|
||||||
|
policy = make_policy(cfg=cfg, env_cfg=env_cfg)
|
||||||
|
policy.eval()
|
||||||
|
policy.to("cuda")
|
||||||
|
|
||||||
|
action_1 = policy.model.generate_actions(**inputs, steps=10).squeeze(0).float().cpu().numpy()
|
||||||
Reference in New Issue
Block a user