Move test to specific folder

This commit is contained in:
Pepijn
2025-09-17 23:42:37 +02:00
parent 0f62c180d9
commit 2f76894ac8
4 changed files with 0 additions and 0 deletions
+187
View File
@@ -0,0 +1,187 @@
#!/usr/bin/env python
"""Test script to verify PI0.5 (pi05) support in PI0OpenPI policy, only meant to be run locally!"""
import os
import pytest
import torch
# Skip this entire module in CI
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="This test requires local OpenPI installation and is not meant for CI",
)
from lerobot.policies.pi05_openpi import PI05OpenPIConfig, PI05OpenPIPolicy # noqa: E402
from tests.utils import require_cuda # noqa: E402
@require_cuda
def test_pi05_model_architecture():
"""Test that pi05=True creates the correct model architecture."""
# Create config
config = PI05OpenPIConfig(
max_action_dim=7,
max_state_dim=14,
dtype="float32",
)
# Set up input_features and output_features in the config
from lerobot.configs.types import FeatureType, PolicyFeature
config.input_features = {
"observation.state": PolicyFeature(
type=FeatureType.STATE,
shape=(14,),
),
"observation.images.base_0_rgb": PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 224, 224),
),
}
config.output_features = {
"action": PolicyFeature(
type=FeatureType.ACTION,
shape=(7,),
),
}
assert config.tokenizer_max_length == 200, (
f"Expected tokenizer_max_length=200 for pi05, got {config.tokenizer_max_length}"
)
assert config.discrete_state_input == True, ( # noqa: E712
f"Expected discrete_state_input=True for pi05, got {config.discrete_state_input}"
)
# Create dummy dataset stats
dataset_stats = {
"observation.state": {
"mean": torch.zeros(14),
"std": torch.ones(14),
},
"action": {
"mean": torch.zeros(7),
"std": torch.ones(7),
},
"observation.images.base_0_rgb": {
"mean": torch.zeros(3, 224, 224),
"std": torch.ones(3, 224, 224),
},
}
# Instantiate policy
policy = PI05OpenPIPolicy(config, dataset_stats)
# Verify pi05 model components exist
# Check that time_mlp layers exist (for AdaRMS conditioning)
assert hasattr(policy.model, "time_mlp_in"), "Missing time_mlp_in layer for pi05"
assert hasattr(policy.model, "time_mlp_out"), "Missing time_mlp_out layer for pi05"
# Check that action_time_mlp layers don't exist (pi0 only)
assert not hasattr(policy.model, "action_time_mlp_in"), "action_time_mlp_in should not exist in pi05 mode"
assert not hasattr(policy.model, "action_time_mlp_out"), (
"action_time_mlp_out should not exist in pi05 mode"
)
# Check that state_proj doesn't exist in pi05 mode
assert not hasattr(policy.model, "state_proj"), "state_proj should not exist in pi05 mode"
# Check AdaRMS configuration in the underlying model
adarms_config = policy.model.paligemma_with_expert.paligemma.config.text_config.use_adarms
assert adarms_config == False, f"PaliGemma should not use AdaRMS, got {adarms_config}" # noqa: E712
adarms_expert_config = policy.model.paligemma_with_expert.gemma_expert.config.use_adarms
assert adarms_expert_config == True, ( # noqa: E712
f"Action expert should use AdaRMS in pi05, got {adarms_expert_config}"
)
@require_cuda
def test_pi05_forward_pass():
"""Test forward pass with"""
# Create config
config = PI05OpenPIConfig(
max_action_dim=7,
max_state_dim=14,
dtype="float32",
chunk_size=16, # Shorter chunk_size for testing
n_action_steps=16, # Shorter action steps for testing
)
# Set up input_features and output_features in the config
from lerobot.configs.types import FeatureType, PolicyFeature
config.input_features = {
"observation.state": PolicyFeature(
type=FeatureType.STATE,
shape=(14,),
),
"observation.images.base_0_rgb": PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 224, 224),
),
}
config.output_features = {
"action": PolicyFeature(
type=FeatureType.ACTION,
shape=(7,),
),
}
# Create dummy dataset stats
dataset_stats = {
"observation.state": {
"mean": torch.zeros(14),
"std": torch.ones(14),
},
"action": {
"mean": torch.zeros(7),
"std": torch.ones(7),
},
"observation.images.base_0_rgb": {
"mean": torch.zeros(3, 224, 224),
"std": torch.ones(3, 224, 224),
},
}
# Instantiate policy
policy = PI05OpenPIPolicy(config, dataset_stats)
# Create test batch
batch_size = 2
device = next(policy.parameters()).device
batch = {
"observation.state": torch.randn(batch_size, 14, dtype=torch.float32, device=device),
"action": torch.randn(batch_size, config.chunk_size, 7, dtype=torch.float32, device=device),
"observation.images.base_0_rgb": torch.rand(
batch_size, 3, 224, 224, dtype=torch.float32, device=device
),
"task": ["Pick up the object"] * batch_size,
}
# Test forward pass
try:
loss, loss_dict = policy.forward(batch)
print(f"Forward pass successful. Loss: {loss_dict['loss']:.4f}")
assert not torch.isnan(loss), "Loss is NaN"
assert loss.item() >= 0, "Loss should be non-negative"
except Exception as e:
print(f"Forward pass failed: {e}")
raise
# Test action prediction
try:
with torch.no_grad():
action = policy.select_action(batch)
print(f"Action prediction successful. Action shape: {action.shape}")
# When batch_size > 1, select_action returns (batch_size, action_dim)
assert action.shape == (batch_size, 7), f"Expected action shape ({batch_size}, 7), got {action.shape}"
assert not torch.isnan(action).any(), "Action contains NaN values"
except Exception as e:
print(f"Action prediction failed: {e}")
raise
+109
View File
@@ -0,0 +1,109 @@
#!/usr/bin/env python
"""Test script to verify PI0OpenPI policy integration with LeRobot, only meant to be run locally!"""
import os
import pytest
import torch
# Skip this entire module in CI
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="This test requires local OpenPI installation and is not meant for CI",
)
from lerobot.policies.factory import make_policy_config # noqa: E402
from lerobot.policies.pi0_openpi import PI0OpenPIConfig, PI0OpenPIPolicy # noqa: E402
from tests.utils import require_cuda # noqa: E402
@require_cuda
def test_policy_instantiation():
# Create config
config = PI0OpenPIConfig(max_action_dim=7, max_state_dim=14, dtype="float32")
# Set up input_features and output_features in the config
from lerobot.configs.types import FeatureType, PolicyFeature
config.input_features = {
"observation.state": PolicyFeature(
type=FeatureType.STATE,
shape=(14,),
),
"observation.images.base_0_rgb": PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 224, 224),
),
}
config.output_features = {
"action": PolicyFeature(
type=FeatureType.ACTION,
shape=(7,),
),
}
# Create dummy dataset stats
dataset_stats = {
"observation.state": {
"mean": torch.zeros(14),
"std": torch.ones(14),
},
"action": {
"mean": torch.zeros(7),
"std": torch.ones(7),
},
"observation.images.base_0_rgb": {
"mean": torch.zeros(3, 224, 224),
"std": torch.ones(3, 224, 224),
},
}
# Instantiate policy
policy = PI0OpenPIPolicy(config, dataset_stats)
# Test forward pass with dummy data
batch_size = 1
device = policy.device if hasattr(policy, "device") else "cpu"
batch = {
"observation.state": torch.randn(batch_size, 14, dtype=torch.float32, device=device),
"action": torch.randn(batch_size, config.chunk_size, 7, dtype=torch.float32, device=device),
"observation.images.base_0_rgb": torch.rand(
batch_size, 3, 224, 224, dtype=torch.float32, device=device
), # Use rand for [0,1] range
"task": ["Pick up the object"] * batch_size,
}
try:
loss, loss_dict = policy.forward(batch)
print(f"Forward pass successful. Loss: {loss_dict['loss']:.4f}")
except Exception as e:
print(f"Forward pass failed: {e}")
raise
try:
with torch.no_grad():
action = policy.select_action(batch)
print(f"Action prediction successful. Action shape: {action.shape}")
except Exception as e:
print(f"Action prediction failed: {e}")
raise
@require_cuda
def test_config_creation():
"""Test policy config creation through factory."""
try:
config = make_policy_config(
policy_type="pi0_openpi",
max_action_dim=7,
max_state_dim=14,
)
print("Config created successfully through factory")
print(f" Config type: {type(config).__name__}")
print(f" PaliGemma variant: {config.paligemma_variant}")
print(f" Action expert variant: {config.action_expert_variant}")
except Exception as e:
print(f"Config creation failed: {e}")
raise
@@ -0,0 +1,403 @@
"""Test script to verify PI0OpenPI policy integration with LeRobot vs the original implementation, only meant to be run locally!"""
import os
import pytest
import torch
# Skip if openpi or transformers is not available
pytest.importorskip("openpi")
pytest.importorskip("transformers")
# Skip this entire module in CI
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="This test requires local OpenPI installation and is not meant for CI",
)
from openpi.models_pytorch import preprocessing_pytorch as openpi_preprocessing # noqa: E402
# NOTE: Assumes PYTHONPATH is set to include OpenPI src as per instructions.
from openpi.models_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402
from transformers import AutoTokenizer # noqa: E402
from lerobot.policies.pi0_openpi import PI0OpenPIConfig, PI0OpenPIPolicy # noqa: E402
DUMMY_ACTION_DIM = 32
DUMMY_STATE_DIM = 32
DUMMY_ACTION_HORIZON = 50
DUMMY_MAX_TOKEN_LEN = 48 # Default for PI0 (non-pi05)
DEVICE = "cpu" # Use CPU to avoid memory issues for testing
DUMMY_DATASET_STATS = {
"observation.state": {
"mean": torch.zeros(DUMMY_STATE_DIM),
"std": torch.ones(DUMMY_STATE_DIM),
},
"action": {
"mean": torch.zeros(DUMMY_ACTION_DIM),
"std": torch.ones(DUMMY_ACTION_DIM),
},
"images": {
"base_0_rgb": {
"mean": torch.zeros(3, 224, 224),
"std": torch.ones(3, 224, 224),
},
"left_wrist_0_rgb": {
"mean": torch.zeros(3, 224, 224),
"std": torch.ones(3, 224, 224),
},
"right_wrist_0_rgb": {
"mean": torch.zeros(3, 224, 224),
"std": torch.ones(3, 224, 224),
},
},
}
class PI0BaseOriginalConfig:
action_dim: int = DUMMY_ACTION_DIM
action_horizon: int = DUMMY_ACTION_HORIZON
paligemma_variant: str = "gemma_2b"
action_expert_variant: str = "gemma_300m"
precision: str = "float32"
pi05: bool = False
dtype: str = "float32"
def instantiate_lerobot_pi0(from_pretrained: bool = False):
if from_pretrained:
# Load the policy first
policy = PI0OpenPIPolicy.from_pretrained(
pretrained_name_or_path="pepijn223/pi0_base_fp32", strict=True
)
# Then reinitialize the normalization with proper stats
from lerobot.policies.normalize import Normalize, Unnormalize
policy.normalize_inputs = Normalize(
policy.config.input_features, policy.config.normalization_mapping, DUMMY_DATASET_STATS
)
policy.normalize_targets = Normalize(
policy.config.output_features, policy.config.normalization_mapping, DUMMY_DATASET_STATS
)
policy.unnormalize_outputs = Unnormalize(
policy.config.output_features, policy.config.normalization_mapping, DUMMY_DATASET_STATS
)
else:
config = PI0OpenPIConfig(
max_action_dim=DUMMY_ACTION_DIM, max_state_dim=DUMMY_STATE_DIM, dtype="float32"
)
policy = PI0OpenPIPolicy(config, DUMMY_DATASET_STATS)
policy.to(DEVICE)
return policy
def instantiate_original_pi0(from_pretrained: bool = False, model_path: str = None):
config = PI0BaseOriginalConfig()
policy = PI0Pytorch(config)
if from_pretrained:
try:
print("Loading converted PyTorch weights from HuggingFace Hub (pepijn223/pi0_base_fp32)...")
# Download the model from HuggingFace Hub
import safetensors.torch
from huggingface_hub import snapshot_download
# Download the entire repository
if model_path and os.path.exists(model_path):
cache_dir = model_path
print(f"Using cached model from: {cache_dir}")
else:
cache_dir = snapshot_download(repo_id="pepijn223/pi0_base_fp32", repo_type="model")
print(f"Downloaded model to: {cache_dir}")
# Try to load safetensors format first
model_file = os.path.join(cache_dir, "model.safetensors")
if os.path.exists(model_file):
state_dict = safetensors.torch.load_file(model_file)
print(f"Loaded {len(state_dict)} parameters from safetensors")
else:
raise FileNotFoundError(f"No safetensors file found in {cache_dir}")
# Load the state dict into the model
missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
if missing_keys:
print(f"Missing keys: {len(missing_keys)}")
if len(missing_keys) <= 5:
for key in missing_keys:
print(f" - {key}")
else:
for key in missing_keys[:5]:
print(f" - {key}")
print(f" ... and {len(missing_keys) - 5} more")
if unexpected_keys:
print(f"Unexpected keys: {len(unexpected_keys)}")
if len(unexpected_keys) <= 5:
for key in unexpected_keys:
print(f" - {key}")
else:
for key in unexpected_keys[:5]:
print(f" - {key}")
print(f" ... and {len(unexpected_keys) - 5} more")
if not missing_keys and not unexpected_keys:
print("All pretrained weights loaded successfully!")
else:
print("Pretrained weights loaded with some missing/unexpected keys (this may be normal)")
except Exception as e:
print(f"Failed to load pretrained weights: {e}")
print(" Using randomly initialized weights...")
import traceback
traceback.print_exc()
policy.to(DEVICE)
return policy
def create_dummy_data():
batch_size = 2 # Reduce batch size for testing
device = DEVICE
# Use the exact same prompt for both implementations
prompt = "Pick up the red block and place it in the bin"
batch = {
"observation.state": torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device),
"action": torch.randn(
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=device
),
# Create images in [0, 1] range as expected by LeRobot (will be converted to [-1, 1] internally)
"observation.images.base_0_rgb": torch.rand(
batch_size, 3, 224, 224, dtype=torch.float32, device=device
),
"observation.images.left_wrist_0_rgb": torch.rand(
batch_size, 3, 224, 224, dtype=torch.float32, device=device
),
"observation.images.right_wrist_0_rgb": torch.rand(
batch_size, 3, 224, 224, dtype=torch.float32, device=device
),
# Add the task prompt for LeRobot - provide as list with single element to trigger expansion
"task": [prompt],
}
return batch
def extract_lerobot_processed_inputs(lerobot_pi0, batch):
"""Extract the exact same processed inputs that LeRobot uses internally."""
# Get the tokenized language from LeRobot's internal method
lang_tokens, lang_masks = lerobot_pi0._tokenize_language(batch)
# Get the preprocessed images from LeRobot's internal method
images, img_masks = lerobot_pi0._preprocess_images(batch, train=False)
# Create dummy token_ar_mask and token_loss_mask for original implementation
token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)
return images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask
class PI0Observation:
"""Observation class that matches the original OpenPI format."""
def __init__(
self,
state,
images,
image_masks,
tokenized_prompt,
tokenized_prompt_mask,
token_ar_mask,
token_loss_mask,
):
self.state = state
self.images = images
self.image_masks = image_masks
self.tokenized_prompt = tokenized_prompt
self.tokenized_prompt_mask = tokenized_prompt_mask
self.token_ar_mask = token_ar_mask
self.token_loss_mask = token_loss_mask
def create_original_observation_with_openpi_preprocessing(batch):
"""Create observation object for OpenPI using OpenPI's own preprocessing."""
batch_size = batch["observation.state"].shape[0]
device = batch["observation.state"].device
# Create tokenizer for OpenPI (same as LeRobot uses)
tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
# Get task description
if "task" in batch:
tasks = batch["task"]
if isinstance(tasks, str):
tasks = [tasks]
elif isinstance(tasks, list) and len(tasks) == 1:
# Expand to batch size
tasks = tasks * batch_size
else:
# Default task if not provided
tasks = ["Pick up the object"] * batch_size
# Tokenize with max_length padding to match OpenPI's expected format
tokenized = tokenizer(
tasks,
padding="max_length",
padding_side="right",
truncation=True,
max_length=DUMMY_MAX_TOKEN_LEN,
return_tensors="pt",
)
lang_tokens = tokenized["input_ids"].to(device)
lang_masks = tokenized["attention_mask"].to(device, dtype=torch.bool)
# Create dummy token_ar_mask and token_loss_mask for OpenPI
token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)
# Convert LeRobot images format to OpenPI format (convert [0,1] to [-1,1] range)
image_dict = {
"base_0_rgb": batch["observation.images.base_0_rgb"] * 2.0 - 1.0,
"left_wrist_0_rgb": batch["observation.images.left_wrist_0_rgb"] * 2.0 - 1.0,
"right_wrist_0_rgb": batch["observation.images.right_wrist_0_rgb"] * 2.0 - 1.0,
}
# Create image masks (all ones for real images)
image_masks_dict = {}
for key in image_dict:
image_masks_dict[key] = torch.ones(batch_size, dtype=torch.bool, device=device)
# Create raw observation object (before preprocessing)
raw_observation = PI0Observation(
state=batch["observation.state"],
images=image_dict,
image_masks=image_masks_dict,
tokenized_prompt=lang_tokens,
tokenized_prompt_mask=lang_masks,
token_ar_mask=token_ar_mask,
token_loss_mask=token_loss_mask,
)
# Now use OpenPI's preprocessing
processed_obs = openpi_preprocessing.preprocess_observation_pytorch(raw_observation, train=False)
return processed_obs
def create_original_observation_from_lerobot(lerobot_pi0, batch):
"""Create observation object compatible with original OpenPI using the exact same inputs as LeRobot."""
_batch_size = batch["observation.state"].shape[0]
_device = batch["observation.state"].device
# Extract the exact same processed inputs that LeRobot uses
images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask = (
extract_lerobot_processed_inputs(lerobot_pi0, batch)
)
# Convert images list to dict with original OpenPI keys
image_dict = {
"base_0_rgb": images[0],
"left_wrist_0_rgb": images[1],
"right_wrist_0_rgb": images[2],
}
# Convert image masks list to dict with original OpenPI keys
image_masks_dict = {
"base_0_rgb": img_masks[0],
"left_wrist_0_rgb": img_masks[1],
"right_wrist_0_rgb": img_masks[2],
}
return PI0Observation(
state=batch["observation.state"],
images=image_dict,
image_masks=image_masks_dict,
tokenized_prompt=lang_tokens,
tokenized_prompt_mask=lang_masks,
token_ar_mask=token_ar_mask,
token_loss_mask=token_loss_mask,
)
def test_pi0_original_vs_lerobot():
"""Test PI0 original implementation vs LeRobot implementation."""
print("Initializing models...")
lerobot_pi0 = instantiate_lerobot_pi0(from_pretrained=True) # Load pretrained LeRobot model
original_pi0 = instantiate_original_pi0(
from_pretrained=True
) # Load pretrained OpenPI model from HuggingFace Hub
print("Creating dummy data...")
batch = create_dummy_data()
# Test 1: Each model with its own preprocessing (more realistic end-to-end test)
print("\nTEST 1: Each model with its own preprocessing")
print("Creating observation for OpenPI using OpenPI's own preprocessing...")
pi0_obs_openpi = create_original_observation_with_openpi_preprocessing(batch)
print(f"Task prompt: '{batch['task'][0]}'")
print(f"OpenPI tokenized prompt shape: {pi0_obs_openpi.tokenized_prompt.shape}")
print(f"OpenPI image shapes: {[img.shape for img in pi0_obs_openpi.images.values()]}")
print(f"OpenPI state shape: {pi0_obs_openpi.state.shape}")
print("Testing OpenPI with own preprocessing...")
original_pi0.eval()
torch.manual_seed(42) # Set seed for reproducibility
batch_size = batch["observation.state"].shape[0]
noise_shape = (batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM)
fixed_noise = torch.randn(noise_shape, dtype=torch.float32, device=DEVICE)
with torch.no_grad():
openpi_actions = original_pi0.sample_actions(
device=DEVICE, observation=pi0_obs_openpi, noise=fixed_noise, num_steps=10
)
print(f"OpenPI (own preprocessing) Actions shape: {openpi_actions.shape}")
print(f"OpenPI (own preprocessing) Actions mean: {openpi_actions.mean().item():.6f}")
print(f"OpenPI (own preprocessing) Actions std: {openpi_actions.std().item():.6f}")
print("Testing LeRobot with own preprocessing...")
lerobot_pi0.eval()
torch.manual_seed(42) # Set the same seed
with torch.no_grad():
lerobot_actions_own = lerobot_pi0.predict_action_chunk(batch)
print(f"LeRobot (own preprocessing) Actions shape: {lerobot_actions_own.shape}")
print(f"LeRobot (own preprocessing) Actions mean: {lerobot_actions_own.mean().item():.6f}")
print(f"LeRobot (own preprocessing) Actions std: {lerobot_actions_own.std().item():.6f}")
print("\nComparing end-to-end implementations:")
print(f"Actions close (atol=1e-4): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)}")
print(f"Actions close (atol=1e-2): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)}")
print(f"Max absolute difference: {torch.abs(lerobot_actions_own - openpi_actions).max().item():.6f}")
# Test 2: Both models with LeRobot preprocessing (isolates model differences)
print("\nTEST 2: Both models with LeRobot preprocessing (model comparison)")
print("Creating observation for OpenPI using LeRobot's preprocessing...")
pi0_obs_lerobot = create_original_observation_from_lerobot(lerobot_pi0, batch)
print("Testing OpenPI with LeRobot preprocessing...")
torch.manual_seed(42) # Set seed for reproducibility
with torch.no_grad():
openpi_actions_lerobot_preproc = original_pi0.sample_actions(
device=DEVICE, observation=pi0_obs_lerobot, noise=fixed_noise, num_steps=10
)
print(f"OpenPI (LeRobot preprocessing) Actions shape: {openpi_actions_lerobot_preproc.shape}")
print(f"OpenPI (LeRobot preprocessing) Actions mean: {openpi_actions_lerobot_preproc.mean().item():.6f}")
print(f"OpenPI (LeRobot preprocessing) Actions std: {openpi_actions_lerobot_preproc.std().item():.6f}")
print("\nComparing models with same preprocessing:")
is_close_1e4 = torch.allclose(lerobot_actions_own, openpi_actions_lerobot_preproc, atol=1e-4)
is_close_1e2 = torch.allclose(lerobot_actions_own, openpi_actions_lerobot_preproc, atol=1e-2)
max_diff = torch.abs(lerobot_actions_own - openpi_actions_lerobot_preproc).max().item()
print(f"Actions close (atol=1e-4): {is_close_1e4}")
print(f"Actions close (atol=1e-2): {is_close_1e2}")
print(f"Max absolute difference: {max_diff:.6f}")
# Add assertions for pytest
assert is_close_1e2, f"Models should produce similar results (atol=1e-2), max diff: {max_diff}"
@@ -0,0 +1,219 @@
#!/usr/bin/env python
# TODO(pepijn): Remove these tests before merging
"""Test script to load PI0OpenPI model from HuggingFace hub and run inference."""
import os
import pytest
import torch
# Skip entire module if transformers is not available
pytest.importorskip("transformers")
# Skip this entire module in CI
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="This test requires HuggingFace authentication and is not meant for CI",
)
from lerobot.policies.pi0_openpi import PI0OpenPIPolicy # noqa: E402
from lerobot.policies.pi05_openpi.modeling_pi05openpi import PI05OpenPIPolicy # noqa: E402
def create_dummy_stats(config):
"""Create dummy dataset statistics for testing."""
dummy_stats = {
"observation.state": {
"mean": torch.zeros(config.max_state_dim),
"std": torch.ones(config.max_state_dim),
},
"action": {
"mean": torch.zeros(config.max_action_dim),
"std": torch.ones(config.max_action_dim),
},
}
# Add stats for image keys if they exist
for key in config.image_features.keys():
dummy_stats[key] = {
"mean": torch.zeros(3, config.image_resolution[0], config.image_resolution[1]),
"std": torch.ones(3, config.image_resolution[0], config.image_resolution[1]),
}
return dummy_stats
# Test data for all 6 base models
MODEL_TEST_PARAMS = [
# PI0 models
("pepijn223/pi0_base_fp32", "PI0", PI0OpenPIPolicy),
("pepijn223/pi0_droid_fp32", "PI0", PI0OpenPIPolicy),
("pepijn223/pi0_libero_fp32", "PI0", PI0OpenPIPolicy),
# PI0.5 models
("pepijn223/pi05_base_fp32", "PI0.5", PI05OpenPIPolicy),
("pepijn223/pi05_droid_fp32", "PI0.5", PI05OpenPIPolicy),
("pepijn223/pi05_libero_fp32", "PI0.5", PI05OpenPIPolicy),
]
@pytest.mark.parametrize("model_id,model_type,policy_class", MODEL_TEST_PARAMS)
def test_all_base_models_hub_loading(model_id, model_type, policy_class):
"""Test loading and basic functionality of all 6 base models from HuggingFace Hub.
Args:
model_id: HuggingFace model ID (e.g., "pepijn223/pi0_base_fp32")
model_type: Model type ("PI0" or "PI0.5")
policy_class: Policy class to use (PI0OpenPIPolicy or PI05OpenPIPolicy)
"""
print(f"\n{'=' * 80}")
print(f"Testing {model_type} model: {model_id}")
print(f"{'=' * 80}")
# Load the model from HuggingFace hub
try:
policy = policy_class.from_pretrained(model_id, strict=True)
print(f"✓ Successfully loaded {model_type} model from {model_id}")
except Exception as e:
print(f"✗ Failed to load model {model_id}: {e}")
raise
# Set up input_features and output_features in the config (not set by from_pretrained)
from lerobot.configs.types import FeatureType, PolicyFeature
policy.config.input_features = {
"observation.state": PolicyFeature(
type=FeatureType.STATE,
shape=(policy.config.max_state_dim,),
),
"observation.images.base_0_rgb": PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 224, 224),
),
"observation.images.left_wrist_0_rgb": PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 224, 224),
),
"observation.images.right_wrist_0_rgb": PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 224, 224),
),
}
policy.config.output_features = {
"action": PolicyFeature(
type=FeatureType.ACTION,
shape=(policy.config.max_action_dim,),
),
}
# Get model info
device = next(policy.parameters()).device
print("\nModel configuration:")
print(f" - Model ID: {model_id}")
print(f" - Model type: {model_type}")
print(f" - PaliGemma variant: {policy.config.paligemma_variant}")
print(f" - Action expert variant: {policy.config.action_expert_variant}")
print(f" - Action dimension: {policy.config.max_action_dim}")
print(f" - State dimension: {policy.config.max_state_dim}")
print(f" - Chunk size: {policy.config.chunk_size}")
print(f" - Tokenizer max length: {policy.config.tokenizer_max_length}")
print(f" - Device: {device}")
print(f" - Dtype: {next(policy.parameters()).dtype}")
# Verify model-specific architecture
if model_type == "PI0.5":
print(f" - discrete_state_input: {policy.config.discrete_state_input}")
# Verify PI0.5 specific features
assert hasattr(policy.model, "time_mlp_in"), f"{model_id}: PI0.5 should have time_mlp_in"
assert hasattr(policy.model, "time_mlp_out"), f"{model_id}: PI0.5 should have time_mlp_out"
assert not hasattr(policy.model, "state_proj"), f"{model_id}: PI0.5 should not have state_proj"
assert not hasattr(policy.model, "action_time_mlp_in"), (
f"{model_id}: PI0.5 should not have action_time_mlp_in"
)
adarms_expert_config = policy.model.paligemma_with_expert.gemma_expert.config.use_adarms
assert adarms_expert_config == True, f"{model_id}: PI0.5 expert should use AdaRMS" # noqa: E712
print(" ✓ PI0.5 architecture verified")
else:
# Verify PI0 specific features
assert hasattr(policy.model, "action_time_mlp_in"), f"{model_id}: PI0 should have action_time_mlp_in"
assert hasattr(policy.model, "action_time_mlp_out"), (
f"{model_id}: PI0 should have action_time_mlp_out"
)
assert hasattr(policy.model, "state_proj"), f"{model_id}: PI0 should have state_proj"
assert not hasattr(policy.model, "time_mlp_in"), f"{model_id}: PI0 should not have time_mlp_in"
adarms_expert_config = policy.model.paligemma_with_expert.gemma_expert.config.use_adarms
assert adarms_expert_config == False, f"{model_id}: PI0 expert should not use AdaRMS" # noqa: E712
print(" ✓ PI0 architecture verified")
# Create dummy stats for testing
dummy_stats = create_dummy_stats(policy.config)
for key, stats in dummy_stats.items():
dummy_stats[key] = {
"mean": stats["mean"].to(device),
"std": stats["std"].to(device),
}
# Initialize normalization layers with dummy stats
from lerobot.policies.normalize import Normalize, Unnormalize
policy.normalize_inputs = Normalize(
policy.config.input_features, policy.config.normalization_mapping, dummy_stats
)
policy.normalize_targets = Normalize(
policy.config.output_features, policy.config.normalization_mapping, dummy_stats
)
policy.unnormalize_outputs = Unnormalize(
policy.config.output_features, policy.config.normalization_mapping, dummy_stats
)
# Create test batch
batch_size = 1
batch = {
"observation.state": torch.randn(
batch_size, policy.config.max_state_dim, dtype=torch.float32, device=device
),
"action": torch.randn(
batch_size,
policy.config.chunk_size,
policy.config.max_action_dim,
dtype=torch.float32,
device=device,
),
"task": ["Pick up the object"] * batch_size,
}
# Add images based on config
for key in policy.config.image_features.keys():
batch[key] = torch.rand(batch_size, 3, 224, 224, dtype=torch.float32, device=device)
# Test forward pass
print(f"\nTesting forward pass for {model_id}...")
try:
policy.train()
loss, loss_dict = policy.forward(batch)
assert not torch.isnan(loss), f"{model_id}: Forward pass produced NaN loss"
assert loss.item() >= 0, f"{model_id}: Loss should be non-negative"
print(f"✓ Forward pass successful - Loss: {loss_dict['loss']:.4f}")
except Exception as e:
print(f"✗ Forward pass failed for {model_id}: {e}")
raise
# Test action prediction
print(f"Testing action prediction for {model_id}...")
try:
policy.eval()
with torch.no_grad():
action = policy.select_action(batch)
expected_shape = (batch_size, policy.config.max_action_dim)
assert action.shape == expected_shape, (
f"{model_id}: Expected action shape {expected_shape}, got {action.shape}"
)
assert not torch.isnan(action).any(), f"{model_id}: Action contains NaN values"
print(f"✓ Action prediction successful - Shape: {action.shape}")
except Exception as e:
print(f"✗ Action prediction failed for {model_id}: {e}")
raise
print(f"All tests passed for {model_id}!")