mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-17 01:30:14 +00:00
feat(policies): Add X-VLA (#2405)
* first commit * more fixes * add franka action * update testing script * add changes * update files * logits matching * add imagenet as a norm type * logits matching atol1e-2 * more eval fixes * more changes * xvla works on libero * remove seed * more refactoring * more fixes * more changes * more changes * more fixes * migrate policy revert * major pre-commit cleanup * renaming * revert to self.transformer * refactor * new changes * clean * update libero * more changes * make it work * more changes: * remove imagenet dependency * style * more * more refactor * remove proprio * add loss * more * more * add freeze/unfreeze options * add testing * upgrade transformers version * update testing * add installation * remove .sh file * fix testing * silent linter in xvlatest * fix failing test * upgrade test, fix failing * fix testing * more fixes to testing * require cuda in tests * temp check * add xvla docs * fix styling * update libero doc * remove timm dep * add different dtype support * remove timm skip * remove white lines * Enhance X-VLA finetuning documentation with optimizer details (#2537) Added detailed instructions for implementing a custom optimizer and modifying parameter retrieval for X-VLA finetuning. Signed-off-by: Jinliang Zheng <54488861+2toinf@users.noreply.github.com> * fix style * iterate on review * iterate on cpilot * revert xvla dep * free up ci * test(xvla): remove main test (#2565) * Add xvla custom optim and dtype (#2567) * add custom optim * add custom optim * add auto mode * more changes * add identity to all * add auto * release * add docs * make image smaller docs * smaller image in doc * evan smaller image doc * finalize doc --------- Signed-off-by: Jinliang Zheng <54488861+2toinf@users.noreply.github.com> Signed-off-by: Steven Palma <imstevenpmwork@ieee.org> Co-authored-by: Jinliang Zheng <54488861+2toinf@users.noreply.github.com> Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co> Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
This commit is contained in:
@@ -0,0 +1,318 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test script to verify XVLA policy integration with LeRobot vs the original implementation, only meant to be run locally!"""
|
||||
# ruff: noqa: E402
|
||||
|
||||
import random
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
||||
from lerobot.policies.xvla.modeling_xvla import XVLAPolicy
|
||||
from lerobot.policies.xvla.processor_xvla import make_xvla_pre_post_processors
|
||||
from lerobot.processor import PolicyAction, PolicyProcessorPipeline # noqa: E402
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE # noqa: E402
|
||||
from tests.utils import require_cuda # noqa: E402
|
||||
|
||||
# Constants
|
||||
DUMMY_ACTION_DIM = 7 # Standard robot arm action dimension
|
||||
DUMMY_STATE_DIM = 20 # Proprioceptive state dimension
|
||||
IMAGE_HEIGHT = 224
|
||||
IMAGE_WIDTH = 224
|
||||
NUM_VIEWS = 2 # Number of camera views
|
||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
MODEL_PATH_LEROBOT = "lerobot/xvla-widowx"
|
||||
LIBERO_DOMAIN_ID = 0 # Domain ID for examples purposes
|
||||
|
||||
# Expected values from original XVLA implementation (reference values)
|
||||
EXPECTED_ACTIONS_SHAPE = (30, 20)
|
||||
EXPECTED_ACTIONS_MEAN = 0.117606
|
||||
EXPECTED_ACTIONS_STD = 0.245411
|
||||
EXPECTED_ACTIONS_FIRST_5 = torch.tensor([0.2742, 0.4977, 0.0500, 0.7040, -0.2653])
|
||||
|
||||
|
||||
def set_seed_all(seed: int):
|
||||
"""Set random seed for all RNG sources to ensure reproducibility."""
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
# Set deterministic behavior
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
torch.use_deterministic_algorithms(True, warn_only=True)
|
||||
|
||||
|
||||
def instantiate_lerobot_xvla(
|
||||
from_pretrained: bool = False,
|
||||
model_path: str = MODEL_PATH_LEROBOT,
|
||||
) -> tuple[
|
||||
Any, # Policy
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""Instantiate LeRobot XVLA policy with preprocessor and postprocessor."""
|
||||
if from_pretrained:
|
||||
policy = XVLAPolicy.from_pretrained(
|
||||
pretrained_name_or_path=model_path,
|
||||
strict=False,
|
||||
)
|
||||
else:
|
||||
config = XVLAConfig(
|
||||
base_model_path=model_path,
|
||||
n_action_steps=DUMMY_ACTION_DIM,
|
||||
chunk_size=DUMMY_ACTION_DIM,
|
||||
device=DEVICE,
|
||||
num_image_views=NUM_VIEWS,
|
||||
) # add resize_imgs_with_padding=IMAGE_SIZE, IMAGE_SIZE?
|
||||
policy = XVLAPolicy(config)
|
||||
|
||||
policy.to(DEVICE)
|
||||
policy.config.device = DEVICE
|
||||
preprocessor, postprocessor = make_xvla_pre_post_processors(
|
||||
config=policy.config,
|
||||
dataset_stats=None, # Pass None for dataset_stats to disable normalization (original XVLA doesn't normalize)
|
||||
)
|
||||
|
||||
return policy, preprocessor, postprocessor
|
||||
|
||||
|
||||
def create_dummy_data(device=DEVICE):
|
||||
"""Create dummy data for testing both implementations."""
|
||||
batch_size = 1
|
||||
prompt = "Pick up the red block and place it in the bin"
|
||||
|
||||
# Create random RGB images in [0, 255] uint8 range (as PIL images would be)
|
||||
# Then convert to [0, 1] float32 range for LeRobot
|
||||
def fake_rgb(h, w):
|
||||
arr = np.random.randint(0, 255, (h, w, 3), dtype=np.uint8)
|
||||
t = torch.from_numpy(arr).permute(2, 0, 1) # CHW
|
||||
return t
|
||||
|
||||
batch = {
|
||||
f"{OBS_IMAGES}.image": torch.stack(
|
||||
[fake_rgb(IMAGE_HEIGHT, IMAGE_WIDTH) for _ in range(batch_size)]
|
||||
).to(device),
|
||||
f"{OBS_IMAGES}.image2": torch.stack(
|
||||
[fake_rgb(IMAGE_HEIGHT, IMAGE_WIDTH) for _ in range(batch_size)]
|
||||
).to(device),
|
||||
OBS_STATE: torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device),
|
||||
"task": [prompt for _ in range(batch_size)],
|
||||
}
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
# Pytest fixtures
|
||||
@pytest.fixture(scope="module")
|
||||
def xvla_components():
|
||||
"""Fixture to instantiate and provide all XVLA components for tests."""
|
||||
print(f"\nTesting with DEVICE='{DEVICE}'")
|
||||
print("\n[Setup] Instantiating LeRobot XVLA policy...")
|
||||
policy_obj, preprocessor_obj, postprocessor_obj = instantiate_lerobot_xvla(from_pretrained=True)
|
||||
print("✔️ Model loaded successfully")
|
||||
yield policy_obj, preprocessor_obj, postprocessor_obj
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def policy(xvla_components):
|
||||
"""Fixture to provide the XVLA policy for tests."""
|
||||
return xvla_components[0]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def preprocessor(xvla_components):
|
||||
"""Fixture to provide the XVLA preprocessor for tests."""
|
||||
return xvla_components[1]
|
||||
|
||||
|
||||
@require_cuda
|
||||
def test_xvla_preprocessor_alignment(policy, preprocessor):
|
||||
"""Test that LeRobot XVLA preprocessor produces expected outputs."""
|
||||
print("\n" + "=" * 80)
|
||||
print("Test: XVLA Preprocessor Outputs")
|
||||
print("=" * 80)
|
||||
|
||||
set_seed_all(42)
|
||||
|
||||
print("\nCreating dummy data...")
|
||||
batch = create_dummy_data()
|
||||
|
||||
print("\n[LeRobot] Preprocessing...")
|
||||
lerobot_observation = preprocessor(deepcopy(batch))
|
||||
lerobot_inputs = policy._build_model_inputs(lerobot_observation)
|
||||
|
||||
print("\nVerifying preprocessor outputs:")
|
||||
print("-" * 80)
|
||||
|
||||
# Expected shapes from tester.txt
|
||||
expected_shapes = {
|
||||
"domain_id": (1,),
|
||||
"input_ids": (1, 50),
|
||||
"proprio": (1, 20),
|
||||
"image_mask": (1, 2),
|
||||
"image_input": (1, 2, 3, 224, 224),
|
||||
}
|
||||
|
||||
for key, expected_shape in expected_shapes.items():
|
||||
if key in lerobot_inputs:
|
||||
actual_shape = tuple(lerobot_inputs[key].shape)
|
||||
print(f"\nKey: {key}")
|
||||
print(f"Expected shape: {expected_shape}")
|
||||
print(f"Actual shape: {actual_shape}")
|
||||
|
||||
if actual_shape == expected_shape:
|
||||
print("Shape matches!")
|
||||
else:
|
||||
print("Shape mismatch!")
|
||||
|
||||
assert actual_shape == expected_shape, f"Shape mismatch for {key}"
|
||||
else:
|
||||
print(f"\nKey '{key}' not found in inputs!")
|
||||
|
||||
print("\nAll preprocessor outputs have correct shapes!")
|
||||
|
||||
|
||||
@require_cuda
|
||||
def test_xvla_action_generation(policy, preprocessor):
|
||||
"""Test XVLA LeRobot implementation generates expected actions."""
|
||||
print("\n" + "=" * 80)
|
||||
print("Test: XVLA Action Generation Against Expected Values")
|
||||
print("=" * 80)
|
||||
|
||||
set_seed_all(42)
|
||||
|
||||
print("\nCreating dummy data...")
|
||||
batch = create_dummy_data()
|
||||
|
||||
print("\n[LeRobot] Running inference...")
|
||||
lerobot_observation = preprocessor(deepcopy(batch))
|
||||
lerobot_inputs = policy._build_model_inputs(lerobot_observation)
|
||||
|
||||
# Reset seed for inference
|
||||
torch.manual_seed(42)
|
||||
with torch.no_grad():
|
||||
lerobot_actions = policy.model.generate_actions(**lerobot_inputs, steps=10)
|
||||
lerobot_actions = lerobot_actions.squeeze(0).float().cpu()
|
||||
|
||||
print(f"LeRobot actions shape: {lerobot_actions.shape}")
|
||||
print(f"LeRobot actions mean: {lerobot_actions.mean().item():.6f}")
|
||||
print(f"LeRobot actions std: {lerobot_actions.std().item():.6f}")
|
||||
print(f"LeRobot actions first 5: {lerobot_actions[0, :5]}")
|
||||
|
||||
print("\nExpected values (from original XVLA):")
|
||||
print(f"Expected actions shape: {EXPECTED_ACTIONS_SHAPE}")
|
||||
print(f"Expected actions mean: {EXPECTED_ACTIONS_MEAN:.6f}")
|
||||
print(f"Expected actions std: {EXPECTED_ACTIONS_STD:.6f}")
|
||||
print(f"Expected actions first 5: {EXPECTED_ACTIONS_FIRST_5}")
|
||||
|
||||
print("\nAction Comparison:")
|
||||
print("-" * 80)
|
||||
|
||||
# Compare shapes
|
||||
actual_shape = tuple(lerobot_actions.shape)
|
||||
assert actual_shape == EXPECTED_ACTIONS_SHAPE, (
|
||||
f"Shape mismatch: {actual_shape} vs {EXPECTED_ACTIONS_SHAPE}"
|
||||
)
|
||||
print(f"✔️ Shape matches: {actual_shape}")
|
||||
|
||||
# Compare statistics
|
||||
actual_mean = lerobot_actions.mean().item()
|
||||
actual_std = lerobot_actions.std().item()
|
||||
|
||||
mean_diff = abs(actual_mean - EXPECTED_ACTIONS_MEAN)
|
||||
std_diff = abs(actual_std - EXPECTED_ACTIONS_STD)
|
||||
|
||||
print(f"\nMean: {actual_mean:.6f} (expected: {EXPECTED_ACTIONS_MEAN:.6f}, diff: {mean_diff:.6e})")
|
||||
print(f"Std: {actual_std:.6f} (expected: {EXPECTED_ACTIONS_STD:.6f}, diff: {std_diff:.6e})")
|
||||
|
||||
# Compare first 5 actions
|
||||
actual_first_5 = lerobot_actions[0, :5]
|
||||
first_5_diff = torch.abs(actual_first_5 - EXPECTED_ACTIONS_FIRST_5)
|
||||
|
||||
print("\nFirst 5 actions comparison:")
|
||||
print(f" Actual: {actual_first_5}")
|
||||
print(f" Expected: {EXPECTED_ACTIONS_FIRST_5}")
|
||||
print(f" Max diff: {first_5_diff.max().item():.6e}")
|
||||
print(f" Mean diff: {first_5_diff.mean().item():.6e}")
|
||||
|
||||
# Check with different tolerances
|
||||
tolerances = [1e-5, 1e-4, 1e-3, 1e-2]
|
||||
for tol in tolerances:
|
||||
is_close = torch.allclose(actual_first_5, EXPECTED_ACTIONS_FIRST_5, atol=tol)
|
||||
status = "Success" if is_close else "Failure"
|
||||
print(f"{status}: First 5 actions close (atol={tol}): {is_close}")
|
||||
|
||||
# Assert with reasonable tolerance
|
||||
tolerance = 1e-3
|
||||
assert torch.allclose(actual_first_5, EXPECTED_ACTIONS_FIRST_5, atol=tolerance), (
|
||||
f"First 5 actions differ by more than tolerance ({tolerance})"
|
||||
)
|
||||
print(f"\nSuccess: Actions match expected values within tolerance ({tolerance})!")
|
||||
|
||||
|
||||
@require_cuda
|
||||
def test_xvla_inference_reproducibility(policy, preprocessor):
|
||||
"""Test that XVLA inference is reproducible with the same seed."""
|
||||
print("\n" + "=" * 80)
|
||||
print("Test: XVLA Inference Reproducibility")
|
||||
print("=" * 80)
|
||||
|
||||
print("\nCreating dummy data...")
|
||||
batch = create_dummy_data()
|
||||
|
||||
# First inference
|
||||
print("\n[Run 1] Running inference...")
|
||||
set_seed_all(42)
|
||||
lerobot_observation = preprocessor(deepcopy(batch))
|
||||
lerobot_inputs = policy._build_model_inputs(lerobot_observation)
|
||||
with torch.no_grad():
|
||||
actions_1 = policy.model.generate_actions(**lerobot_inputs, steps=10)
|
||||
actions_1 = actions_1.squeeze(0).float().cpu()
|
||||
|
||||
# Second inference with same seed
|
||||
print("\n[Run 2] Running inference with same seed...")
|
||||
set_seed_all(42)
|
||||
lerobot_observation = preprocessor(deepcopy(batch))
|
||||
lerobot_inputs = policy._build_model_inputs(lerobot_observation)
|
||||
with torch.no_grad():
|
||||
actions_2 = policy.model.generate_actions(**lerobot_inputs, steps=10)
|
||||
actions_2 = actions_2.squeeze(0).float().cpu()
|
||||
|
||||
print("\nComparing two runs:")
|
||||
print("-" * 80)
|
||||
if torch.allclose(actions_1, actions_2, atol=1e-8):
|
||||
print("Inference is perfectly reproducible!")
|
||||
else:
|
||||
diff = torch.abs(actions_1 - actions_2)
|
||||
print("Small differences detected:")
|
||||
print(f" Max diff: {diff.max().item():.6e}")
|
||||
print(f" Mean diff: {diff.mean().item():.6e}")
|
||||
|
||||
assert torch.allclose(actions_1, actions_2, atol=1e-6), "Inference should be reproducible!"
|
||||
|
||||
print("\nInference is reproducible!")
|
||||
Reference in New Issue
Block a user