mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 03:30:10 +00:00
add comment pi0 pi05 tests, add image features to pi0 pi05 hub tests
This commit is contained in:
@@ -1,15 +1,20 @@
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
|
|
||||||
"""Test script to verify PI0.5 (pi05) support in PI0OpenPI policy."""
|
"""Test script to verify PI0.5 (pi05) support in PI0OpenPI policy, only meant to be run locally!"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
# Skip entire module if transformers is not available
|
# Skip this entire module in CI
|
||||||
pytest.importorskip("transformers")
|
pytestmark = pytest.mark.skipif(
|
||||||
|
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||||
|
reason="This test requires local OpenPI installation and is not meant for CI",
|
||||||
|
)
|
||||||
|
|
||||||
from lerobot.policies.pi05_openpi import PI05OpenPIConfig, PI05OpenPIPolicy
|
from lerobot.policies.pi05_openpi import PI05OpenPIConfig, PI05OpenPIPolicy # noqa: E402
|
||||||
from tests.utils import require_cuda
|
from tests.utils import require_cuda # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
@require_cuda
|
@require_cuda
|
||||||
|
|||||||
@@ -1,16 +1,21 @@
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
|
|
||||||
"""Test script to verify PI0OpenPI policy integration with LeRobot."""
|
"""Test script to verify PI0OpenPI policy integration with LeRobot, only meant to be run locally!"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
# Skip entire module if transformers is not available
|
# Skip this entire module in CI
|
||||||
pytest.importorskip("transformers")
|
pytestmark = pytest.mark.skipif(
|
||||||
|
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||||
|
reason="This test requires local OpenPI installation and is not meant for CI",
|
||||||
|
)
|
||||||
|
|
||||||
from lerobot.policies.factory import make_policy_config
|
from lerobot.policies.factory import make_policy_config # noqa: E402
|
||||||
from lerobot.policies.pi0_openpi import PI0OpenPIConfig, PI0OpenPIPolicy
|
from lerobot.policies.pi0_openpi import PI0OpenPIConfig, PI0OpenPIPolicy # noqa: E402
|
||||||
from tests.utils import require_cuda
|
from tests.utils import require_cuda # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
@require_cuda
|
@require_cuda
|
||||||
|
|||||||
@@ -45,208 +45,6 @@ def create_dummy_stats(config):
|
|||||||
return dummy_stats
|
return dummy_stats
|
||||||
|
|
||||||
|
|
||||||
def test_pi0_hub_loading():
|
|
||||||
"""Test loading PI0 model from HuggingFace hub."""
|
|
||||||
_test_hub_loading(model_id="pepijn223/pi0_base_fp32", model_name="PI0")
|
|
||||||
|
|
||||||
|
|
||||||
def test_pi05_hub_loading():
|
|
||||||
"""Test loading PI0.5 model from HuggingFace hub."""
|
|
||||||
_test_hub_loading(model_id="pepijn223/pi05_base_fp32", model_name="PI0.5")
|
|
||||||
|
|
||||||
|
|
||||||
def _test_hub_loading(model_id, model_name):
|
|
||||||
"""Internal helper function for testing hub loading.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_id: HuggingFace model ID to load
|
|
||||||
model_name: Display name for the model (e.g., "PI0", "PI0.5")
|
|
||||||
"""
|
|
||||||
print("=" * 60)
|
|
||||||
print(f"{model_name} OpenPI HuggingFace Hub Loading Test")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
print(f"\nLoading model from: {model_id}")
|
|
||||||
print("-" * 60)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Load the model from HuggingFace hub with strict mode
|
|
||||||
if model_name == "PI0.5":
|
|
||||||
policy = PI05OpenPIPolicy.from_pretrained(
|
|
||||||
model_id,
|
|
||||||
strict=True, # Ensure all weights are loaded correctly,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
policy = PI0OpenPIPolicy.from_pretrained(
|
|
||||||
model_id,
|
|
||||||
strict=True, # Ensure all weights are loaded correctly,
|
|
||||||
)
|
|
||||||
|
|
||||||
print("✓ Model loaded successfully from HuggingFace hub")
|
|
||||||
|
|
||||||
# Inject dummy stats since they aren't loaded from the hub
|
|
||||||
print("Creating dummy dataset stats for testing...")
|
|
||||||
device = next(policy.parameters()).device
|
|
||||||
dummy_stats = create_dummy_stats(policy.config)
|
|
||||||
|
|
||||||
# Move dummy stats to device
|
|
||||||
for key, stats in dummy_stats.items():
|
|
||||||
dummy_stats[key] = {
|
|
||||||
"mean": stats["mean"].to(device),
|
|
||||||
"std": stats["std"].to(device),
|
|
||||||
}
|
|
||||||
|
|
||||||
# Initialize normalization layers with dummy stats if they have NaN/inf values
|
|
||||||
print("✓ Dummy stats created and moved to device")
|
|
||||||
|
|
||||||
# Get model info
|
|
||||||
print("\nModel configuration:")
|
|
||||||
print(f" - Model type: {model_name}")
|
|
||||||
print(f" - PaliGemma variant: {policy.config.paligemma_variant}")
|
|
||||||
print(f" - Action expert variant: {policy.config.action_expert_variant}")
|
|
||||||
print(f" - Action dimension: {policy.config.max_action_dim}")
|
|
||||||
print(f" - State dimension: {policy.config.max_state_dim}")
|
|
||||||
print(f" - Chunk_size: {policy.config.chunk_size}")
|
|
||||||
print(f" - Tokenizer max length: {policy.config.tokenizer_max_length}")
|
|
||||||
if model_name == "PI0.5":
|
|
||||||
print(f" - discrete_state_input: {policy.config.discrete_state_input}")
|
|
||||||
print(f" - Device: {device}")
|
|
||||||
print(f" - Dtype: {next(policy.parameters()).dtype}")
|
|
||||||
|
|
||||||
# Check model-specific features
|
|
||||||
if model_name == "PI0.5":
|
|
||||||
print("\nPI0.5 specific features:")
|
|
||||||
print(f" - Has time_mlp layers: {hasattr(policy.model, 'time_mlp_in')}")
|
|
||||||
print(f" - Has state_proj: {hasattr(policy.model, 'state_proj')} (should be False)")
|
|
||||||
print(f" - Uses AdaRMS: {policy.model.paligemma_with_expert.gemma_expert.config.use_adarms}")
|
|
||||||
|
|
||||||
# Verify PI0.5 architecture
|
|
||||||
assert hasattr(policy.model, "time_mlp_in"), "PI0.5 should have time_mlp_in"
|
|
||||||
assert hasattr(policy.model, "time_mlp_out"), "PI0.5 should have time_mlp_out"
|
|
||||||
assert not hasattr(policy.model, "state_proj"), "PI0.5 should not have state_proj"
|
|
||||||
assert not hasattr(policy.model, "action_time_mlp_in"), "PI0.5 should not have action_time_mlp_in"
|
|
||||||
print(" ✓ PI0.5 architecture verified")
|
|
||||||
else:
|
|
||||||
print("\nPI0 specific features:")
|
|
||||||
print(f" - Has action_time_mlp layers: {hasattr(policy.model, 'action_time_mlp_in')}")
|
|
||||||
print(f" - Has state_proj: {hasattr(policy.model, 'state_proj')} (should be True)")
|
|
||||||
print(
|
|
||||||
f" - Uses AdaRMS: {policy.model.paligemma_with_expert.gemma_expert.config.use_adarms} (should be False)"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify PI0 architecture
|
|
||||||
assert hasattr(policy.model, "action_time_mlp_in"), "PI0 should have action_time_mlp_in"
|
|
||||||
assert hasattr(policy.model, "action_time_mlp_out"), "PI0 should have action_time_mlp_out"
|
|
||||||
assert hasattr(policy.model, "state_proj"), "PI0 should have state_proj"
|
|
||||||
assert not hasattr(policy.model, "time_mlp_in"), "PI0 should not have time_mlp_in"
|
|
||||||
print(" ✓ PI0 architecture verified")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"✗ Failed to load model: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
print("\n" + "-" * 60)
|
|
||||||
print("Testing forward pass with loaded model...")
|
|
||||||
|
|
||||||
# Create dummy batch for testing
|
|
||||||
batch_size = 1
|
|
||||||
|
|
||||||
# Check if normalization layers have invalid stats and replace with dummy stats if needed
|
|
||||||
try:
|
|
||||||
# Check if the normalize_inputs has valid stats
|
|
||||||
if hasattr(policy.normalize_inputs, "stats"):
|
|
||||||
obs_state_mean = policy.normalize_inputs.stats.get("observation.state", {}).get("mean")
|
|
||||||
if obs_state_mean is not None and (
|
|
||||||
torch.isinf(obs_state_mean).any() or torch.isnan(obs_state_mean).any()
|
|
||||||
):
|
|
||||||
print("⚠️ Found invalid normalization stats, replacing with dummy stats...")
|
|
||||||
|
|
||||||
# Replace with dummy stats
|
|
||||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
|
||||||
|
|
||||||
policy.normalize_inputs = Normalize(
|
|
||||||
policy.config.input_features, policy.config.normalization_mapping, dummy_stats
|
|
||||||
)
|
|
||||||
policy.normalize_targets = Normalize(
|
|
||||||
policy.config.output_features, policy.config.normalization_mapping, dummy_stats
|
|
||||||
)
|
|
||||||
policy.unnormalize_outputs = Unnormalize(
|
|
||||||
policy.config.output_features, policy.config.normalization_mapping, dummy_stats
|
|
||||||
)
|
|
||||||
print("✓ Normalization layers updated with dummy stats")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"⚠️ Error checking normalization stats, creating new ones: {e}")
|
|
||||||
# Fallback: create new normalization layers
|
|
||||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
|
||||||
|
|
||||||
policy.normalize_inputs = Normalize(
|
|
||||||
policy.config.input_features, policy.config.normalization_mapping, dummy_stats
|
|
||||||
)
|
|
||||||
policy.normalize_targets = Normalize(
|
|
||||||
policy.config.output_features, policy.config.normalization_mapping, dummy_stats
|
|
||||||
)
|
|
||||||
policy.unnormalize_outputs = Unnormalize(
|
|
||||||
policy.config.output_features, policy.config.normalization_mapping, dummy_stats
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create test batch
|
|
||||||
batch = {
|
|
||||||
"observation.state": torch.randn(
|
|
||||||
batch_size, policy.config.max_state_dim, dtype=torch.float32, device=device
|
|
||||||
),
|
|
||||||
"action": torch.randn(
|
|
||||||
batch_size,
|
|
||||||
policy.config.chunk_size,
|
|
||||||
policy.config.max_action_dim,
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=device,
|
|
||||||
),
|
|
||||||
"task": ["Pick up the object"] * batch_size,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Add images if they're in the config
|
|
||||||
for key in policy.config.image_features.keys():
|
|
||||||
batch[key] = torch.rand(batch_size, 3, 224, 224, dtype=torch.float32, device=device)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Test forward pass
|
|
||||||
policy.train() # Set to training mode for forward pass with loss
|
|
||||||
loss, loss_dict = policy.forward(batch)
|
|
||||||
print("✓ Forward pass successful")
|
|
||||||
print(f" - Loss: {loss_dict['loss']:.4f}")
|
|
||||||
print(f" - Loss shape: {loss.shape if hasattr(loss, 'shape') else 'scalar'}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"✗ Forward pass failed: {e}")
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
traceback.print_exc()
|
|
||||||
raise
|
|
||||||
|
|
||||||
print("\n" + "-" * 60)
|
|
||||||
print("Testing inference with loaded model...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Test action prediction
|
|
||||||
policy.eval() # Set to evaluation mode for inference
|
|
||||||
with torch.no_grad():
|
|
||||||
action = policy.select_action(batch)
|
|
||||||
print("✓ Action prediction successful")
|
|
||||||
print(f" - Action shape: {action.shape}")
|
|
||||||
print(f" - Action range: [{action.min().item():.3f}, {action.max().item():.3f}]")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"✗ Action prediction failed: {e}")
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
traceback.print_exc()
|
|
||||||
raise
|
|
||||||
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print(f"✓ All tests passed for {model_name}!")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
|
|
||||||
# Test data for all 6 base models
|
# Test data for all 6 base models
|
||||||
MODEL_TEST_PARAMS = [
|
MODEL_TEST_PARAMS = [
|
||||||
# PI0 models
|
# PI0 models
|
||||||
@@ -281,6 +79,35 @@ def test_all_base_models_hub_loading(model_id, model_type, policy_class):
|
|||||||
print(f"✗ Failed to load model {model_id}: {e}")
|
print(f"✗ Failed to load model {model_id}: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
# Set up input_features and output_features in the config (not set by from_pretrained)
|
||||||
|
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||||
|
|
||||||
|
policy.config.input_features = {
|
||||||
|
"observation.state": PolicyFeature(
|
||||||
|
type=FeatureType.STATE,
|
||||||
|
shape=(policy.config.max_state_dim,),
|
||||||
|
),
|
||||||
|
"observation.images.base_0_rgb": PolicyFeature(
|
||||||
|
type=FeatureType.VISUAL,
|
||||||
|
shape=(3, 224, 224),
|
||||||
|
),
|
||||||
|
"observation.images.left_wrist_0_rgb": PolicyFeature(
|
||||||
|
type=FeatureType.VISUAL,
|
||||||
|
shape=(3, 224, 224),
|
||||||
|
),
|
||||||
|
"observation.images.right_wrist_0_rgb": PolicyFeature(
|
||||||
|
type=FeatureType.VISUAL,
|
||||||
|
shape=(3, 224, 224),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
policy.config.output_features = {
|
||||||
|
"action": PolicyFeature(
|
||||||
|
type=FeatureType.ACTION,
|
||||||
|
shape=(policy.config.max_action_dim,),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
# Get model info
|
# Get model info
|
||||||
device = next(policy.parameters()).device
|
device = next(policy.parameters()).device
|
||||||
print("\nModel configuration:")
|
print("\nModel configuration:")
|
||||||
|
|||||||
Reference in New Issue
Block a user