mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-18 02:00:03 +00:00
put tests in test folder
This commit is contained in:
@@ -0,0 +1,185 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
"""Test script to verify PI0.5 (pi05) support in PI0OpenPI policy."""
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.policies.pi0_openpi.configuration_pi0openpi import PI0OpenPIConfig
|
||||
from lerobot.policies.pi0_openpi.modeling_pi0openpi import PI0OpenPIPolicy
|
||||
from lerobot.policies.pi05_openpi import PI05OpenPIConfig, PI05OpenPIPolicy
|
||||
from tests.utils import require_nightly_gpu
|
||||
|
||||
|
||||
@require_nightly_gpu
|
||||
def test_pi05_model_architecture():
|
||||
"""Test that pi05=True creates the correct model architecture."""
|
||||
print("Testing PI0.5 model architecture...")
|
||||
|
||||
# Create config
|
||||
config = PI05OpenPIConfig(
|
||||
action_dim=7,
|
||||
state_dim=14,
|
||||
dtype="float32",
|
||||
)
|
||||
|
||||
# Verify tokenizer max length is set correctly
|
||||
assert config.tokenizer_max_length == 200, (
|
||||
f"Expected tokenizer_max_length=200 for pi05, got {config.tokenizer_max_length}"
|
||||
)
|
||||
print(f"✓ Tokenizer max length correctly set to {config.tokenizer_max_length}")
|
||||
|
||||
# Verify discrete_state_input defaults to pi05
|
||||
assert config.discrete_state_input == True, ( # noqa: E712
|
||||
f"Expected discrete_state_input=True for pi05, got {config.discrete_state_input}"
|
||||
)
|
||||
print(f"✓ discrete_state_input correctly defaults to pi05 value: {config.discrete_state_input}")
|
||||
|
||||
# Create dummy dataset stats
|
||||
dataset_stats = {
|
||||
"observation.state": {
|
||||
"mean": torch.zeros(14),
|
||||
"std": torch.ones(14),
|
||||
},
|
||||
"action": {
|
||||
"mean": torch.zeros(7),
|
||||
"std": torch.ones(7),
|
||||
},
|
||||
}
|
||||
|
||||
# Instantiate policy
|
||||
policy = PI05OpenPIPolicy(config, dataset_stats)
|
||||
|
||||
# Verify pi05 model components exist
|
||||
|
||||
# Check that time_mlp layers exist (for AdaRMS conditioning)
|
||||
assert hasattr(policy.model, "time_mlp_in"), "Missing time_mlp_in layer for pi05"
|
||||
assert hasattr(policy.model, "time_mlp_out"), "Missing time_mlp_out layer for pi05"
|
||||
print("✓ Time MLP layers present for AdaRMS conditioning")
|
||||
|
||||
# Check that action_time_mlp layers don't exist (pi0 only)
|
||||
assert not hasattr(policy.model, "action_time_mlp_in"), "action_time_mlp_in should not exist in pi05 mode"
|
||||
assert not hasattr(policy.model, "action_time_mlp_out"), (
|
||||
"action_time_mlp_out should not exist in pi05 mode"
|
||||
)
|
||||
print("✓ Action-time MLP layers correctly absent")
|
||||
|
||||
# Check that state_proj doesn't exist in pi05 mode
|
||||
assert not hasattr(policy.model, "state_proj"), "state_proj should not exist in pi05 mode"
|
||||
print("✓ State projection layer correctly absent")
|
||||
|
||||
# Check AdaRMS configuration in the underlying model
|
||||
adarms_config = policy.model.paligemma_with_expert.paligemma.config.text_config.use_adarms
|
||||
assert adarms_config == False, f"PaliGemma should not use AdaRMS, got {adarms_config}" # noqa: E712
|
||||
|
||||
adarms_expert_config = policy.model.paligemma_with_expert.gemma_expert.config.use_adarms
|
||||
assert adarms_expert_config == True, ( # noqa: E712
|
||||
f"Action expert should use AdaRMS in pi05, got {adarms_expert_config}"
|
||||
)
|
||||
print("✓ AdaRMS correctly configured: PaliGemma=False, Expert=True")
|
||||
|
||||
|
||||
@require_nightly_gpu
|
||||
def test_pi05_forward_pass():
|
||||
"""Test forward pass with"""
|
||||
print("\nTesting PI0.5 forward pass...")
|
||||
|
||||
# Create config
|
||||
config = PI05OpenPIConfig(
|
||||
action_dim=7,
|
||||
state_dim=14,
|
||||
dtype="float32",
|
||||
chunk_size=16, # Shorter chunk_size for testing
|
||||
n_action_steps=16, # Shorter action steps for testing
|
||||
)
|
||||
|
||||
# Create dummy dataset stats
|
||||
dataset_stats = {
|
||||
"observation.state": {
|
||||
"mean": torch.zeros(14),
|
||||
"std": torch.ones(14),
|
||||
},
|
||||
"action": {
|
||||
"mean": torch.zeros(7),
|
||||
"std": torch.ones(7),
|
||||
},
|
||||
}
|
||||
|
||||
# Instantiate policy
|
||||
policy = PI05OpenPIPolicy(config, dataset_stats)
|
||||
|
||||
# Create test batch
|
||||
batch_size = 2
|
||||
device = next(policy.parameters()).device
|
||||
batch = {
|
||||
"observation.state": torch.randn(batch_size, 14, dtype=torch.float32, device=device),
|
||||
"action": torch.randn(batch_size, config.chunk_size, 7, dtype=torch.float32, device=device),
|
||||
"observation.images.base_0_rgb": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||
),
|
||||
"task": ["Pick up the object"] * batch_size,
|
||||
}
|
||||
|
||||
# Test forward pass
|
||||
try:
|
||||
loss, loss_dict = policy.forward(batch)
|
||||
print(f"✓ Forward pass successful. Loss: {loss_dict['loss']:.4f}")
|
||||
assert not torch.isnan(loss), "Loss is NaN"
|
||||
assert loss.item() >= 0, "Loss should be non-negative"
|
||||
except Exception as e:
|
||||
print(f"✗ Forward pass failed: {e}")
|
||||
raise
|
||||
|
||||
# Test action prediction
|
||||
try:
|
||||
with torch.no_grad():
|
||||
action = policy.select_action(batch)
|
||||
print(f"✓ Action prediction successful. Action shape: {action.shape}")
|
||||
# When batch_size > 1, select_action returns (batch_size, action_dim)
|
||||
assert action.shape == (batch_size, 7), f"Expected action shape ({batch_size}, 7), got {action.shape}"
|
||||
assert not torch.isnan(action).any(), "Action contains NaN values"
|
||||
except Exception as e:
|
||||
print(f"✗ Action prediction failed: {e}")
|
||||
raise
|
||||
|
||||
|
||||
@require_nightly_gpu
|
||||
def test_pi0_vs_pi05_differences():
|
||||
"""Test key differences between pi0 and pi05 modes."""
|
||||
print("\nComparing PI0 vs PI0.5 architectures...")
|
||||
|
||||
# Create both configurations
|
||||
config_pi0 = PI0OpenPIConfig(action_dim=7, state_dim=14, dtype="float32")
|
||||
config_pi05 = PI05OpenPIConfig(action_dim=7, state_dim=14, dtype="float32")
|
||||
|
||||
dataset_stats = {
|
||||
"observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
|
||||
"action": {"mean": torch.zeros(7), "std": torch.ones(7)},
|
||||
}
|
||||
|
||||
# Create both models
|
||||
policy_pi0 = PI0OpenPIPolicy(config_pi0, dataset_stats)
|
||||
policy_pi05 = PI05OpenPIPolicy(config_pi05, dataset_stats)
|
||||
|
||||
print("\nPI0 Model:")
|
||||
print(f" - Tokenizer max length: {config_pi0.tokenizer_max_length}")
|
||||
print(f" - Has state_proj: {hasattr(policy_pi0.model, 'state_proj')}")
|
||||
print(f" - Has action_time_mlp: {hasattr(policy_pi0.model, 'action_time_mlp_in')}")
|
||||
print(f" - Has time_mlp: {hasattr(policy_pi0.model, 'time_mlp_in')}")
|
||||
print(f" - Uses AdaRMS: {policy_pi0.model.paligemma_with_expert.gemma_expert.config.use_adarms}")
|
||||
|
||||
print("\nPI0.5 Model:")
|
||||
print(f" - Tokenizer max length: {config_pi05.tokenizer_max_length}")
|
||||
print(f" - discrete_state_input: {config_pi05.discrete_state_input}")
|
||||
print(f" - Has state_proj: {hasattr(policy_pi05.model, 'state_proj')}")
|
||||
print(f" - Has action_time_mlp: {hasattr(policy_pi05.model, 'action_time_mlp_in')}")
|
||||
print(f" - Has time_mlp: {hasattr(policy_pi05.model, 'time_mlp_in')}")
|
||||
print(f" - Uses AdaRMS: {policy_pi05.model.paligemma_with_expert.gemma_expert.config.use_adarms}")
|
||||
|
||||
# Count parameters
|
||||
pi0_params = sum(p.numel() for p in policy_pi0.parameters())
|
||||
pi05_params = sum(p.numel() for p in policy_pi05.parameters())
|
||||
|
||||
print("\nParameter counts:")
|
||||
print(f" - PI0: {pi0_params:,}")
|
||||
print(f" - PI0.5: {pi05_params:,}")
|
||||
print(f" - Difference: {pi0_params - pi05_params:,} (PI0.5 has fewer params due to no state embedding)")
|
||||
@@ -0,0 +1,86 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
"""Test script to verify PI0OpenPI policy integration with LeRobot."""
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.policies.factory import make_policy_config
|
||||
from lerobot.policies.pi0_openpi import PI0OpenPIConfig, PI0OpenPIPolicy
|
||||
from tests.utils import require_nightly_gpu
|
||||
|
||||
|
||||
@require_nightly_gpu
|
||||
def test_policy_instantiation():
|
||||
"""Test basic policy instantiation."""
|
||||
print("Testing PI0OpenPI policy instantiation...")
|
||||
|
||||
# Create config
|
||||
config = PI0OpenPIConfig(action_dim=7, state_dim=14, dtype="float32")
|
||||
|
||||
# Create dummy dataset stats
|
||||
dataset_stats = {
|
||||
"observation.state": {
|
||||
"mean": torch.zeros(14),
|
||||
"std": torch.ones(14),
|
||||
},
|
||||
"action": {
|
||||
"mean": torch.zeros(7),
|
||||
"std": torch.ones(7),
|
||||
},
|
||||
}
|
||||
|
||||
# Instantiate policy
|
||||
policy = PI0OpenPIPolicy(config, dataset_stats)
|
||||
print(f"Policy created successfully: {policy.name}")
|
||||
|
||||
# Test forward pass with dummy data
|
||||
batch_size = 1
|
||||
device = policy.device if hasattr(policy, "device") else "cpu"
|
||||
batch = {
|
||||
"observation.state": torch.randn(batch_size, 14, dtype=torch.float32, device=device),
|
||||
"action": torch.randn(batch_size, config.chunk_size, 7, dtype=torch.float32, device=device),
|
||||
"observation.images.base_0_rgb": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||
), # Use rand for [0,1] range
|
||||
"task": ["Pick up the object"] * batch_size,
|
||||
}
|
||||
|
||||
print("\nTesting forward pass...")
|
||||
try:
|
||||
loss, loss_dict = policy.forward(batch)
|
||||
print(f"✓ Forward pass successful. Loss: {loss_dict['loss']:.4f}")
|
||||
except Exception as e:
|
||||
print(f"✗ Forward pass failed: {e}")
|
||||
return False
|
||||
|
||||
print("\nTesting action prediction...")
|
||||
try:
|
||||
with torch.no_grad():
|
||||
action = policy.select_action(batch)
|
||||
print(f"✓ Action prediction successful. Action shape: {action.shape}")
|
||||
except Exception as e:
|
||||
print(f"✗ Action prediction failed: {e}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@require_nightly_gpu
|
||||
def test_config_creation():
|
||||
"""Test policy config creation through factory."""
|
||||
print("\nTesting config creation through factory...")
|
||||
|
||||
try:
|
||||
config = make_policy_config(
|
||||
policy_type="pi0_openpi",
|
||||
action_dim=7,
|
||||
state_dim=14,
|
||||
)
|
||||
print("✓ Config created successfully through factory")
|
||||
print(f" Config type: {type(config).__name__}")
|
||||
print(f" PaliGemma variant: {config.paligemma_variant}")
|
||||
print(f" Action expert variant: {config.action_expert_variant}")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"✗ Config creation failed: {e}")
|
||||
return False
|
||||
@@ -0,0 +1,396 @@
|
||||
"""Test script to verify PI0OpenPI policy integration with LeRobot vs the original implementation."""
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
from openpi.models_pytorch import preprocessing_pytorch as openpi_preprocessing
|
||||
|
||||
# NOTE: Assumes PYTHONPATH is set to include OpenPI src as per instructions.
|
||||
from openpi.models_pytorch.pi0_pytorch import PI0Pytorch
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from lerobot.policies.pi0_openpi import PI0OpenPIConfig, PI0OpenPIPolicy
|
||||
from tests.utils import require_nightly_gpu
|
||||
|
||||
DUMMY_ACTION_DIM = 32
|
||||
DUMMY_STATE_DIM = 32
|
||||
DUMMY_ACTION_HORIZON = 50
|
||||
DUMMY_MAX_TOKEN_LEN = 48 # Default for PI0 (non-pi05)
|
||||
DEVICE = "cpu" # Use CPU to avoid memory issues for testing
|
||||
|
||||
DUMMY_DATASET_STATS = {
|
||||
"observation.state": {
|
||||
"mean": torch.zeros(DUMMY_STATE_DIM),
|
||||
"std": torch.ones(DUMMY_STATE_DIM),
|
||||
},
|
||||
"action": {
|
||||
"mean": torch.zeros(DUMMY_ACTION_DIM),
|
||||
"std": torch.ones(DUMMY_ACTION_DIM),
|
||||
},
|
||||
"images": {
|
||||
"base_0_rgb": {
|
||||
"mean": torch.zeros(3, 224, 224),
|
||||
"std": torch.ones(3, 224, 224),
|
||||
},
|
||||
"left_wrist_0_rgb": {
|
||||
"mean": torch.zeros(3, 224, 224),
|
||||
"std": torch.ones(3, 224, 224),
|
||||
},
|
||||
"right_wrist_0_rgb": {
|
||||
"mean": torch.zeros(3, 224, 224),
|
||||
"std": torch.ones(3, 224, 224),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class PI0BaseOriginalConfig:
|
||||
action_dim: int = DUMMY_ACTION_DIM
|
||||
action_horizon: int = DUMMY_ACTION_HORIZON
|
||||
paligemma_variant: str = "gemma_2b"
|
||||
action_expert_variant: str = "gemma_300m"
|
||||
precision: str = "float32"
|
||||
pi05: bool = False
|
||||
dtype: str = "float32"
|
||||
|
||||
|
||||
def instantiate_lerobot_pi0(from_pretrained: bool = False):
|
||||
if from_pretrained:
|
||||
# Load the policy first
|
||||
policy = PI0OpenPIPolicy.from_pretrained(
|
||||
pretrained_name_or_path="pepijn223/pi0_base_fp32", strict=True
|
||||
)
|
||||
# Then reinitialize the normalization with proper stats
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
|
||||
policy.normalize_inputs = Normalize(
|
||||
policy.config.input_features, policy.config.normalization_mapping, DUMMY_DATASET_STATS
|
||||
)
|
||||
policy.normalize_targets = Normalize(
|
||||
policy.config.output_features, policy.config.normalization_mapping, DUMMY_DATASET_STATS
|
||||
)
|
||||
policy.unnormalize_outputs = Unnormalize(
|
||||
policy.config.output_features, policy.config.normalization_mapping, DUMMY_DATASET_STATS
|
||||
)
|
||||
else:
|
||||
config = PI0OpenPIConfig(action_dim=DUMMY_ACTION_DIM, state_dim=DUMMY_STATE_DIM, dtype="float32")
|
||||
policy = PI0OpenPIPolicy(config, DUMMY_DATASET_STATS)
|
||||
policy.to(DEVICE)
|
||||
return policy
|
||||
|
||||
|
||||
def instantiate_original_pi0(from_pretrained: bool = False, model_path: str = None):
|
||||
config = PI0BaseOriginalConfig()
|
||||
policy = PI0Pytorch(config)
|
||||
|
||||
if from_pretrained:
|
||||
try:
|
||||
print("Loading converted PyTorch weights from HuggingFace Hub (pepijn223/pi0_base_fp32)...")
|
||||
|
||||
# Download the model from HuggingFace Hub
|
||||
import safetensors.torch
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
# Download the entire repository
|
||||
if model_path and os.path.exists(model_path):
|
||||
cache_dir = model_path
|
||||
print(f"Using cached model from: {cache_dir}")
|
||||
else:
|
||||
cache_dir = snapshot_download(repo_id="pepijn223/pi0_base_fp32", repo_type="model")
|
||||
print(f"Downloaded model to: {cache_dir}")
|
||||
|
||||
# Try to load safetensors format first
|
||||
model_file = os.path.join(cache_dir, "model.safetensors")
|
||||
if os.path.exists(model_file):
|
||||
state_dict = safetensors.torch.load_file(model_file)
|
||||
print(f"Loaded {len(state_dict)} parameters from safetensors")
|
||||
else:
|
||||
raise FileNotFoundError(f"No safetensors file found in {cache_dir}")
|
||||
|
||||
# Load the state dict into the model
|
||||
missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
|
||||
|
||||
if missing_keys:
|
||||
print(f"Missing keys: {len(missing_keys)}")
|
||||
if len(missing_keys) <= 5:
|
||||
for key in missing_keys:
|
||||
print(f" - {key}")
|
||||
else:
|
||||
for key in missing_keys[:5]:
|
||||
print(f" - {key}")
|
||||
print(f" ... and {len(missing_keys) - 5} more")
|
||||
|
||||
if unexpected_keys:
|
||||
print(f"Unexpected keys: {len(unexpected_keys)}")
|
||||
if len(unexpected_keys) <= 5:
|
||||
for key in unexpected_keys:
|
||||
print(f" - {key}")
|
||||
else:
|
||||
for key in unexpected_keys[:5]:
|
||||
print(f" - {key}")
|
||||
print(f" ... and {len(unexpected_keys) - 5} more")
|
||||
|
||||
if not missing_keys and not unexpected_keys:
|
||||
print("All pretrained weights loaded successfully!")
|
||||
else:
|
||||
print("Pretrained weights loaded with some missing/unexpected keys (this may be normal)")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to load pretrained weights: {e}")
|
||||
print(" Using randomly initialized weights...")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
policy.to(DEVICE)
|
||||
return policy
|
||||
|
||||
|
||||
def create_dummy_data():
|
||||
batch_size = 2 # Reduce batch size for testing
|
||||
device = DEVICE
|
||||
|
||||
# Use the exact same prompt for both implementations
|
||||
prompt = "Pick up the red block and place it in the bin"
|
||||
|
||||
batch = {
|
||||
"observation.state": torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device),
|
||||
"action": torch.randn(
|
||||
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=device
|
||||
),
|
||||
# Create images in [0, 1] range as expected by LeRobot (will be converted to [-1, 1] internally)
|
||||
"observation.images.base_0_rgb": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||
),
|
||||
"observation.images.left_wrist_0_rgb": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||
),
|
||||
"observation.images.right_wrist_0_rgb": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||
),
|
||||
# Add the task prompt for LeRobot - provide as list with single element to trigger expansion
|
||||
"task": [prompt],
|
||||
}
|
||||
return batch
|
||||
|
||||
|
||||
def extract_lerobot_processed_inputs(lerobot_pi0, batch):
|
||||
"""Extract the exact same processed inputs that LeRobot uses internally."""
|
||||
# Get the tokenized language from LeRobot's internal method
|
||||
lang_tokens, lang_masks = lerobot_pi0._tokenize_language(batch)
|
||||
|
||||
# Get the preprocessed images from LeRobot's internal method
|
||||
images, img_masks = lerobot_pi0._preprocess_images(batch, train=False)
|
||||
|
||||
# Create dummy token_ar_mask and token_loss_mask for original implementation
|
||||
token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
|
||||
token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)
|
||||
|
||||
return images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask
|
||||
|
||||
|
||||
class PI0Observation:
|
||||
"""Observation class that matches the original OpenPI format."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
state,
|
||||
images,
|
||||
image_masks,
|
||||
tokenized_prompt,
|
||||
tokenized_prompt_mask,
|
||||
token_ar_mask,
|
||||
token_loss_mask,
|
||||
):
|
||||
self.state = state
|
||||
self.images = images
|
||||
self.image_masks = image_masks
|
||||
self.tokenized_prompt = tokenized_prompt
|
||||
self.tokenized_prompt_mask = tokenized_prompt_mask
|
||||
self.token_ar_mask = token_ar_mask
|
||||
self.token_loss_mask = token_loss_mask
|
||||
|
||||
|
||||
def create_original_observation_with_openpi_preprocessing(batch):
|
||||
"""Create observation object for OpenPI using OpenPI's own preprocessing."""
|
||||
batch_size = batch["observation.state"].shape[0]
|
||||
device = batch["observation.state"].device
|
||||
|
||||
# Create tokenizer for OpenPI (same as LeRobot uses)
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
|
||||
|
||||
# Get task description
|
||||
if "task" in batch:
|
||||
tasks = batch["task"]
|
||||
if isinstance(tasks, str):
|
||||
tasks = [tasks]
|
||||
elif isinstance(tasks, list) and len(tasks) == 1:
|
||||
# Expand to batch size
|
||||
tasks = tasks * batch_size
|
||||
else:
|
||||
# Default task if not provided
|
||||
tasks = ["Pick up the object"] * batch_size
|
||||
|
||||
# Tokenize with max_length padding to match OpenPI's expected format
|
||||
tokenized = tokenizer(
|
||||
tasks,
|
||||
padding="max_length",
|
||||
padding_side="right",
|
||||
truncation=True,
|
||||
max_length=DUMMY_MAX_TOKEN_LEN,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
lang_tokens = tokenized["input_ids"].to(device)
|
||||
lang_masks = tokenized["attention_mask"].to(device, dtype=torch.bool)
|
||||
|
||||
# Create dummy token_ar_mask and token_loss_mask for OpenPI
|
||||
token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
|
||||
token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)
|
||||
|
||||
# Convert LeRobot images format to OpenPI format (convert [0,1] to [-1,1] range)
|
||||
image_dict = {
|
||||
"base_0_rgb": batch["observation.images.base_0_rgb"] * 2.0 - 1.0,
|
||||
"left_wrist_0_rgb": batch["observation.images.left_wrist_0_rgb"] * 2.0 - 1.0,
|
||||
"right_wrist_0_rgb": batch["observation.images.right_wrist_0_rgb"] * 2.0 - 1.0,
|
||||
}
|
||||
|
||||
# Create image masks (all ones for real images)
|
||||
image_masks_dict = {}
|
||||
for key in image_dict:
|
||||
image_masks_dict[key] = torch.ones(batch_size, dtype=torch.bool, device=device)
|
||||
|
||||
# Create raw observation object (before preprocessing)
|
||||
raw_observation = PI0Observation(
|
||||
state=batch["observation.state"],
|
||||
images=image_dict,
|
||||
image_masks=image_masks_dict,
|
||||
tokenized_prompt=lang_tokens,
|
||||
tokenized_prompt_mask=lang_masks,
|
||||
token_ar_mask=token_ar_mask,
|
||||
token_loss_mask=token_loss_mask,
|
||||
)
|
||||
|
||||
# Now use OpenPI's preprocessing
|
||||
processed_obs = openpi_preprocessing.preprocess_observation_pytorch(raw_observation, train=False)
|
||||
|
||||
return processed_obs
|
||||
|
||||
|
||||
def create_original_observation_from_lerobot(lerobot_pi0, batch):
|
||||
"""Create observation object compatible with original OpenPI using the exact same inputs as LeRobot."""
|
||||
_batch_size = batch["observation.state"].shape[0]
|
||||
_device = batch["observation.state"].device
|
||||
|
||||
# Extract the exact same processed inputs that LeRobot uses
|
||||
images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask = (
|
||||
extract_lerobot_processed_inputs(lerobot_pi0, batch)
|
||||
)
|
||||
|
||||
# Convert images list to dict with original OpenPI keys
|
||||
image_dict = {
|
||||
"base_0_rgb": images[0],
|
||||
"left_wrist_0_rgb": images[1],
|
||||
"right_wrist_0_rgb": images[2],
|
||||
}
|
||||
|
||||
# Convert image masks list to dict with original OpenPI keys
|
||||
image_masks_dict = {
|
||||
"base_0_rgb": img_masks[0],
|
||||
"left_wrist_0_rgb": img_masks[1],
|
||||
"right_wrist_0_rgb": img_masks[2],
|
||||
}
|
||||
|
||||
return PI0Observation(
|
||||
state=batch["observation.state"],
|
||||
images=image_dict,
|
||||
image_masks=image_masks_dict,
|
||||
tokenized_prompt=lang_tokens,
|
||||
tokenized_prompt_mask=lang_masks,
|
||||
token_ar_mask=token_ar_mask,
|
||||
token_loss_mask=token_loss_mask,
|
||||
)
|
||||
|
||||
|
||||
@require_nightly_gpu
|
||||
def test_pi0_original_vs_lerobot():
|
||||
"""Test PI0 original implementation vs LeRobot implementation."""
|
||||
print("Initializing models...")
|
||||
lerobot_pi0 = instantiate_lerobot_pi0(from_pretrained=True) # Load pretrained LeRobot model
|
||||
original_pi0 = instantiate_original_pi0(
|
||||
from_pretrained=True
|
||||
) # Load pretrained OpenPI model from HuggingFace Hub
|
||||
|
||||
print("Creating dummy data...")
|
||||
batch = create_dummy_data()
|
||||
|
||||
# Test 1: Each model with its own preprocessing (more realistic end-to-end test)
|
||||
print("\n=== TEST 1: Each model with its own preprocessing ===")
|
||||
print("Creating observation for OpenPI using OpenPI's own preprocessing...")
|
||||
pi0_obs_openpi = create_original_observation_with_openpi_preprocessing(batch)
|
||||
|
||||
print(f"Task prompt: '{batch['task'][0]}'")
|
||||
print(f"OpenPI tokenized prompt shape: {pi0_obs_openpi.tokenized_prompt.shape}")
|
||||
print(f"OpenPI image shapes: {[img.shape for img in pi0_obs_openpi.images.values()]}")
|
||||
print(f"OpenPI state shape: {pi0_obs_openpi.state.shape}")
|
||||
|
||||
print("Testing OpenPI with own preprocessing...")
|
||||
original_pi0.eval()
|
||||
torch.manual_seed(42) # Set seed for reproducibility
|
||||
batch_size = batch["observation.state"].shape[0]
|
||||
noise_shape = (batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM)
|
||||
fixed_noise = torch.randn(noise_shape, dtype=torch.float32, device=DEVICE)
|
||||
|
||||
with torch.no_grad():
|
||||
openpi_actions = original_pi0.sample_actions(
|
||||
device=DEVICE, observation=pi0_obs_openpi, noise=fixed_noise, num_steps=10
|
||||
)
|
||||
print(f"OpenPI (own preprocessing) Actions shape: {openpi_actions.shape}")
|
||||
print(f"OpenPI (own preprocessing) Actions mean: {openpi_actions.mean().item():.6f}")
|
||||
print(f"OpenPI (own preprocessing) Actions std: {openpi_actions.std().item():.6f}")
|
||||
|
||||
print("Testing LeRobot with own preprocessing...")
|
||||
lerobot_pi0.eval()
|
||||
torch.manual_seed(42) # Set the same seed
|
||||
with torch.no_grad():
|
||||
lerobot_actions_own = lerobot_pi0.predict_action_chunk(batch)
|
||||
print(f"LeRobot (own preprocessing) Actions shape: {lerobot_actions_own.shape}")
|
||||
print(f"LeRobot (own preprocessing) Actions mean: {lerobot_actions_own.mean().item():.6f}")
|
||||
print(f"LeRobot (own preprocessing) Actions std: {lerobot_actions_own.std().item():.6f}")
|
||||
|
||||
print("\nComparing end-to-end implementations:")
|
||||
print(f"Actions close (atol=1e-4): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)}")
|
||||
print(f"Actions close (atol=1e-2): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)}")
|
||||
print(f"Max absolute difference: {torch.abs(lerobot_actions_own - openpi_actions).max().item():.6f}")
|
||||
|
||||
# Test 2: Both models with LeRobot preprocessing (isolates model differences)
|
||||
print("\n=== TEST 2: Both models with LeRobot preprocessing (model comparison) ===")
|
||||
print("Creating observation for OpenPI using LeRobot's preprocessing...")
|
||||
pi0_obs_lerobot = create_original_observation_from_lerobot(lerobot_pi0, batch)
|
||||
|
||||
print("Testing OpenPI with LeRobot preprocessing...")
|
||||
torch.manual_seed(42) # Set seed for reproducibility
|
||||
with torch.no_grad():
|
||||
openpi_actions_lerobot_preproc = original_pi0.sample_actions(
|
||||
device=DEVICE, observation=pi0_obs_lerobot, noise=fixed_noise, num_steps=10
|
||||
)
|
||||
print(f"OpenPI (LeRobot preprocessing) Actions shape: {openpi_actions_lerobot_preproc.shape}")
|
||||
print(f"OpenPI (LeRobot preprocessing) Actions mean: {openpi_actions_lerobot_preproc.mean().item():.6f}")
|
||||
print(f"OpenPI (LeRobot preprocessing) Actions std: {openpi_actions_lerobot_preproc.std().item():.6f}")
|
||||
|
||||
print("\nComparing models with same preprocessing:")
|
||||
is_close_1e4 = torch.allclose(lerobot_actions_own, openpi_actions_lerobot_preproc, atol=1e-4)
|
||||
is_close_1e2 = torch.allclose(lerobot_actions_own, openpi_actions_lerobot_preproc, atol=1e-2)
|
||||
max_diff = torch.abs(lerobot_actions_own - openpi_actions_lerobot_preproc).max().item()
|
||||
|
||||
print(f"Actions close (atol=1e-4): {is_close_1e4}")
|
||||
print(f"Actions close (atol=1e-2): {is_close_1e2}")
|
||||
print(f"Max absolute difference: {max_diff:.6f}")
|
||||
|
||||
# Add assertions for pytest
|
||||
assert is_close_1e2, f"Models should produce similar results (atol=1e-2), max diff: {max_diff}"
|
||||
|
||||
print("\n=== SUMMARY ===")
|
||||
print("Test 1 compares end-to-end pipelines (each model with its own preprocessing)")
|
||||
print("Test 2 isolates model differences (both models with LeRobot preprocessing)")
|
||||
print("Both tests completed successfully!")
|
||||
@@ -0,0 +1,236 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
"""Test script to load PI0OpenPI model from HuggingFace hub and run inference."""
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.policies.pi0_openpi import PI0OpenPIPolicy
|
||||
from lerobot.policies.pi05_openpi.modeling_pi05openpi import PI05OpenPIPolicy
|
||||
from tests.utils import require_nightly_gpu
|
||||
|
||||
|
||||
def create_dummy_stats(config):
|
||||
"""Create dummy dataset statistics for testing."""
|
||||
dummy_stats = {
|
||||
"observation.state": {
|
||||
"mean": torch.zeros(config.state_dim),
|
||||
"std": torch.ones(config.state_dim),
|
||||
},
|
||||
"action": {
|
||||
"mean": torch.zeros(config.action_dim),
|
||||
"std": torch.ones(config.action_dim),
|
||||
},
|
||||
}
|
||||
|
||||
# Add stats for image keys if they exist
|
||||
for key in config.image_keys:
|
||||
dummy_stats[key] = {
|
||||
"mean": torch.zeros(3, config.image_resolution[0], config.image_resolution[1]),
|
||||
"std": torch.ones(3, config.image_resolution[0], config.image_resolution[1]),
|
||||
}
|
||||
|
||||
return dummy_stats
|
||||
|
||||
|
||||
@require_nightly_gpu
|
||||
def test_pi0_hub_loading():
|
||||
"""Test loading PI0 model from HuggingFace hub."""
|
||||
_test_hub_loading(model_id="pepijn223/pi0_base_fp32", model_name="PI0")
|
||||
|
||||
|
||||
@require_nightly_gpu
|
||||
def test_pi05_hub_loading():
|
||||
"""Test loading PI0.5 model from HuggingFace hub."""
|
||||
_test_hub_loading(model_id="pepijn223/pi05_base_fp32", model_name="PI0.5")
|
||||
|
||||
|
||||
def _test_hub_loading(model_id, model_name):
|
||||
"""Internal helper function for testing hub loading.
|
||||
|
||||
Args:
|
||||
model_id: HuggingFace model ID to load
|
||||
model_name: Display name for the model (e.g., "PI0", "PI0.5")
|
||||
"""
|
||||
print("=" * 60)
|
||||
print(f"{model_name} OpenPI HuggingFace Hub Loading Test")
|
||||
print("=" * 60)
|
||||
|
||||
print(f"\nLoading model from: {model_id}")
|
||||
print("-" * 60)
|
||||
|
||||
try:
|
||||
# Load the model from HuggingFace hub with strict mode
|
||||
if model_name == "PI0.5":
|
||||
policy = PI05OpenPIPolicy.from_pretrained(
|
||||
model_id,
|
||||
strict=True, # Ensure all weights are loaded correctly,
|
||||
)
|
||||
else:
|
||||
policy = PI0OpenPIPolicy.from_pretrained(
|
||||
model_id,
|
||||
strict=True, # Ensure all weights are loaded correctly,
|
||||
)
|
||||
|
||||
print("✓ Model loaded successfully from HuggingFace hub")
|
||||
|
||||
# Inject dummy stats since they aren't loaded from the hub
|
||||
print("Creating dummy dataset stats for testing...")
|
||||
device = next(policy.parameters()).device
|
||||
dummy_stats = create_dummy_stats(policy.config)
|
||||
|
||||
# Move dummy stats to device
|
||||
for key, stats in dummy_stats.items():
|
||||
dummy_stats[key] = {
|
||||
"mean": stats["mean"].to(device),
|
||||
"std": stats["std"].to(device),
|
||||
}
|
||||
|
||||
# Initialize normalization layers with dummy stats if they have NaN/inf values
|
||||
print("✓ Dummy stats created and moved to device")
|
||||
|
||||
# Get model info
|
||||
print("\nModel configuration:")
|
||||
print(f" - Model type: {model_name}")
|
||||
print(f" - PaliGemma variant: {policy.config.paligemma_variant}")
|
||||
print(f" - Action expert variant: {policy.config.action_expert_variant}")
|
||||
print(f" - Action dimension: {policy.config.action_dim}")
|
||||
print(f" - State dimension: {policy.config.state_dim}")
|
||||
print(f" - Chunk_size: {policy.config.chunk_size}")
|
||||
print(f" - Tokenizer max length: {policy.config.tokenizer_max_length}")
|
||||
if model_name == "PI0.5":
|
||||
print(f" - discrete_state_input: {policy.config.discrete_state_input}")
|
||||
print(f" - Device: {device}")
|
||||
print(f" - Dtype: {next(policy.parameters()).dtype}")
|
||||
|
||||
# Check model-specific features
|
||||
if model_name == "PI0.5":
|
||||
print("\nPI0.5 specific features:")
|
||||
print(f" - Has time_mlp layers: {hasattr(policy.model, 'time_mlp_in')}")
|
||||
print(f" - Has state_proj: {hasattr(policy.model, 'state_proj')} (should be False)")
|
||||
print(f" - Uses AdaRMS: {policy.model.paligemma_with_expert.gemma_expert.config.use_adarms}")
|
||||
|
||||
# Verify PI0.5 architecture
|
||||
assert hasattr(policy.model, "time_mlp_in"), "PI0.5 should have time_mlp_in"
|
||||
assert hasattr(policy.model, "time_mlp_out"), "PI0.5 should have time_mlp_out"
|
||||
assert not hasattr(policy.model, "state_proj"), "PI0.5 should not have state_proj"
|
||||
assert not hasattr(policy.model, "action_time_mlp_in"), "PI0.5 should not have action_time_mlp_in"
|
||||
print(" ✓ PI0.5 architecture verified")
|
||||
else:
|
||||
print("\nPI0 specific features:")
|
||||
print(f" - Has action_time_mlp layers: {hasattr(policy.model, 'action_time_mlp_in')}")
|
||||
print(f" - Has state_proj: {hasattr(policy.model, 'state_proj')} (should be True)")
|
||||
print(
|
||||
f" - Uses AdaRMS: {policy.model.paligemma_with_expert.gemma_expert.config.use_adarms} (should be False)"
|
||||
)
|
||||
|
||||
# Verify PI0 architecture
|
||||
assert hasattr(policy.model, "action_time_mlp_in"), "PI0 should have action_time_mlp_in"
|
||||
assert hasattr(policy.model, "action_time_mlp_out"), "PI0 should have action_time_mlp_out"
|
||||
assert hasattr(policy.model, "state_proj"), "PI0 should have state_proj"
|
||||
assert not hasattr(policy.model, "time_mlp_in"), "PI0 should not have time_mlp_in"
|
||||
print(" ✓ PI0 architecture verified")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Failed to load model: {e}")
|
||||
raise
|
||||
|
||||
print("\n" + "-" * 60)
|
||||
print("Testing forward pass with loaded model...")
|
||||
|
||||
# Create dummy batch for testing
|
||||
batch_size = 1
|
||||
|
||||
# Check if normalization layers have invalid stats and replace with dummy stats if needed
|
||||
try:
|
||||
# Check if the normalize_inputs has valid stats
|
||||
if hasattr(policy.normalize_inputs, "stats"):
|
||||
obs_state_mean = policy.normalize_inputs.stats.get("observation.state", {}).get("mean")
|
||||
if obs_state_mean is not None and (
|
||||
torch.isinf(obs_state_mean).any() or torch.isnan(obs_state_mean).any()
|
||||
):
|
||||
print("⚠️ Found invalid normalization stats, replacing with dummy stats...")
|
||||
|
||||
# Replace with dummy stats
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
|
||||
policy.normalize_inputs = Normalize(
|
||||
policy.config.input_features, policy.config.normalization_mapping, dummy_stats
|
||||
)
|
||||
policy.normalize_targets = Normalize(
|
||||
policy.config.output_features, policy.config.normalization_mapping, dummy_stats
|
||||
)
|
||||
policy.unnormalize_outputs = Unnormalize(
|
||||
policy.config.output_features, policy.config.normalization_mapping, dummy_stats
|
||||
)
|
||||
print("✓ Normalization layers updated with dummy stats")
|
||||
except Exception as e:
|
||||
print(f"⚠️ Error checking normalization stats, creating new ones: {e}")
|
||||
# Fallback: create new normalization layers
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
|
||||
policy.normalize_inputs = Normalize(
|
||||
policy.config.input_features, policy.config.normalization_mapping, dummy_stats
|
||||
)
|
||||
policy.normalize_targets = Normalize(
|
||||
policy.config.output_features, policy.config.normalization_mapping, dummy_stats
|
||||
)
|
||||
policy.unnormalize_outputs = Unnormalize(
|
||||
policy.config.output_features, policy.config.normalization_mapping, dummy_stats
|
||||
)
|
||||
|
||||
# Create test batch
|
||||
batch = {
|
||||
"observation.state": torch.randn(
|
||||
batch_size, policy.config.state_dim, dtype=torch.float32, device=device
|
||||
),
|
||||
"action": torch.randn(
|
||||
batch_size,
|
||||
policy.config.chunk_size,
|
||||
policy.config.action_dim,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
),
|
||||
"task": ["Pick up the object"] * batch_size,
|
||||
}
|
||||
|
||||
# Add images if they're in the config
|
||||
for key in policy.config.image_keys:
|
||||
batch[key] = torch.rand(batch_size, 3, 224, 224, dtype=torch.float32, device=device)
|
||||
|
||||
try:
|
||||
# Test forward pass
|
||||
policy.train() # Set to training mode for forward pass with loss
|
||||
loss, loss_dict = policy.forward(batch)
|
||||
print("✓ Forward pass successful")
|
||||
print(f" - Loss: {loss_dict['loss']:.4f}")
|
||||
print(f" - Loss shape: {loss.shape if hasattr(loss, 'shape') else 'scalar'}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Forward pass failed: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
print("\n" + "-" * 60)
|
||||
print("Testing inference with loaded model...")
|
||||
|
||||
try:
|
||||
# Test action prediction
|
||||
policy.eval() # Set to evaluation mode for inference
|
||||
with torch.no_grad():
|
||||
action = policy.select_action(batch)
|
||||
print("✓ Action prediction successful")
|
||||
print(f" - Action shape: {action.shape}")
|
||||
print(f" - Action range: [{action.min().item():.3f}, {action.max().item():.3f}]")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Action prediction failed: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print(f"✓ All tests passed for {model_name}!")
|
||||
print("=" * 60)
|
||||
@@ -167,6 +167,24 @@ def require_package_arg(func):
|
||||
return wrapper
|
||||
|
||||
|
||||
def require_nightly_gpu(func):
|
||||
"""
|
||||
Decorator that skips the test unless running in nightly environment with GPU.
|
||||
Combines GPU availability check with nightly workflow detection.
|
||||
"""
|
||||
|
||||
@require_cuda
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
# Check if running in nightly workflow (GitHub Actions)
|
||||
is_nightly = os.environ.get("GITHUB_WORKFLOW") == "Nightly"
|
||||
if not is_nightly:
|
||||
pytest.skip("Test only runs in nightly workflow with GPU")
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def require_package(package_name):
|
||||
"""
|
||||
Decorator that skips the test if the specified package is not installed.
|
||||
|
||||
Reference in New Issue
Block a user