mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 19:49:49 +00:00
put tests in test folder
This commit is contained in:
@@ -31,8 +31,6 @@
|
|||||||
title: π₀.₅ (Pi05)
|
title: π₀.₅ (Pi05)
|
||||||
title: "Policies"
|
title: "Policies"
|
||||||
- sections:
|
- sections:
|
||||||
- local: hope_jr
|
|
||||||
title: Hope Jr
|
|
||||||
- local: so101
|
- local: so101
|
||||||
title: SO-101
|
title: SO-101
|
||||||
- local: so100
|
- local: so100
|
||||||
@@ -41,6 +39,8 @@
|
|||||||
title: Koch v1.1
|
title: Koch v1.1
|
||||||
- local: lekiwi
|
- local: lekiwi
|
||||||
title: LeKiwi
|
title: LeKiwi
|
||||||
|
- local: hope_jr
|
||||||
|
title: Hope Jr
|
||||||
- local: reachy2
|
- local: reachy2
|
||||||
title: Reachy 2
|
title: Reachy 2
|
||||||
title: "Robots"
|
title: "Robots"
|
||||||
|
|||||||
@@ -1,193 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
"""Script to create and push a PI0OpenPI model to HuggingFace hub with proper config format."""
|
|
||||||
|
|
||||||
import tempfile
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from huggingface_hub import HfApi, create_repo
|
|
||||||
|
|
||||||
from lerobot.policies.pi0_openpi import PI0OpenPIConfig, PI0OpenPIPolicy
|
|
||||||
|
|
||||||
|
|
||||||
def create_and_push_model(
|
|
||||||
repo_id: str,
|
|
||||||
private: bool = False,
|
|
||||||
token: str = None,
|
|
||||||
):
|
|
||||||
"""Create a PI0OpenPI model with proper config and push to HuggingFace hub.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
repo_id: HuggingFace repository ID (e.g., "username/model-name")
|
|
||||||
private: Whether to create a private repository
|
|
||||||
token: HuggingFace API token (optional, will use cached token if not provided)
|
|
||||||
"""
|
|
||||||
print("=" * 60)
|
|
||||||
print("PI0OpenPI Model Hub Upload")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
# Create configuration
|
|
||||||
print("\nCreating PI0OpenPI configuration...")
|
|
||||||
config = PI0OpenPIConfig(
|
|
||||||
# Model architecture
|
|
||||||
paligemma_variant="gemma_2b",
|
|
||||||
action_expert_variant="gemma_300m",
|
|
||||||
pi05=False, # Use PI0 (not PI0.5)
|
|
||||||
dtype="float32", # Use float32 for compatibility
|
|
||||||
# Input/output dimensions
|
|
||||||
action_dim=32, # see openpi `Pi0Config`
|
|
||||||
state_dim=32,
|
|
||||||
chunk_size=50,
|
|
||||||
n_action_steps=50,
|
|
||||||
# Image inputs, see openpi `model.py, IMAGE_KEYS`
|
|
||||||
image_keys=(
|
|
||||||
"observation.images.base_0_rgb",
|
|
||||||
"observation.images.left_wrist_0_rgb",
|
|
||||||
"observation.images.right_wrist_0_rgb",
|
|
||||||
),
|
|
||||||
# Training settings
|
|
||||||
gradient_checkpointing=False,
|
|
||||||
compile_model=False,
|
|
||||||
device=None, # Auto-detect
|
|
||||||
# Tokenizer settings
|
|
||||||
tokenizer_max_length=48, # see openpi `__post_init__`, use pi0=48 and pi05=200
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f" - Config type: {config.__class__.__name__}")
|
|
||||||
print(f" - PaliGemma variant: {config.paligemma_variant}")
|
|
||||||
print(f" - Action expert variant: {config.action_expert_variant}")
|
|
||||||
print(f" - Action dim: {config.action_dim}")
|
|
||||||
print(f" - State dim: {config.state_dim}")
|
|
||||||
|
|
||||||
# Create dummy dataset stats for normalization
|
|
||||||
print("\nCreating dataset statistics...")
|
|
||||||
dataset_stats = {
|
|
||||||
"observation.state": {
|
|
||||||
"mean": torch.zeros(config.state_dim),
|
|
||||||
"std": torch.ones(config.state_dim),
|
|
||||||
"min": torch.full((config.state_dim,), -5.0),
|
|
||||||
"max": torch.full((config.state_dim,), 5.0),
|
|
||||||
},
|
|
||||||
"action": {
|
|
||||||
"mean": torch.zeros(config.action_dim),
|
|
||||||
"std": torch.ones(config.action_dim),
|
|
||||||
"min": torch.full((config.action_dim,), -1.0),
|
|
||||||
"max": torch.full((config.action_dim,), 1.0),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
# Add image stats
|
|
||||||
for key in config.image_keys:
|
|
||||||
dataset_stats[key] = {
|
|
||||||
"mean": torch.tensor([0.485, 0.456, 0.406]), # TODO(pepijn): fix this, now its ImageNet mean
|
|
||||||
"std": torch.tensor([0.229, 0.224, 0.225]), # TODO(pepijn): fix this, now its ImageNet std
|
|
||||||
"min": torch.tensor([0.0, 0.0, 0.0]),
|
|
||||||
"max": torch.tensor([1.0, 1.0, 1.0]),
|
|
||||||
}
|
|
||||||
|
|
||||||
# Create the policy
|
|
||||||
print("\nInitializing PI0OpenPI policy...")
|
|
||||||
print(" (This may take a moment as it loads the tokenizer and initializes the model)")
|
|
||||||
policy = PI0OpenPIPolicy(config, dataset_stats)
|
|
||||||
|
|
||||||
# Initialize with small random weights (optional - for testing)
|
|
||||||
# Note: In practice, you would load your trained weights here
|
|
||||||
print("\nInitializing model weights...")
|
|
||||||
for name, param in policy.named_parameters():
|
|
||||||
if "weight" in name:
|
|
||||||
if "norm" in name.lower() or "layernorm" in name.lower():
|
|
||||||
torch.nn.init.ones_(param)
|
|
||||||
elif len(param.shape) >= 2:
|
|
||||||
torch.nn.init.xavier_uniform_(param, gain=0.01)
|
|
||||||
else:
|
|
||||||
torch.nn.init.normal_(param, mean=0.0, std=0.01)
|
|
||||||
elif "bias" in name:
|
|
||||||
torch.nn.init.zeros_(param)
|
|
||||||
|
|
||||||
print(f" - Total parameters: {sum(p.numel() for p in policy.parameters()):,}")
|
|
||||||
print(f" - Trainable parameters: {sum(p.numel() for p in policy.parameters() if p.requires_grad):,}")
|
|
||||||
|
|
||||||
# Create temporary directory for saving
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
|
||||||
save_path = Path(tmpdir) / "model"
|
|
||||||
save_path.mkdir(exist_ok=True)
|
|
||||||
|
|
||||||
print(f"\nSaving model to temporary directory: {save_path}")
|
|
||||||
|
|
||||||
# Save the model using LeRobot's save_pretrained method
|
|
||||||
# This ensures the config is saved in the correct format
|
|
||||||
policy.save_pretrained(save_path)
|
|
||||||
|
|
||||||
# List saved files
|
|
||||||
saved_files = list(save_path.glob("*"))
|
|
||||||
print("\nSaved files:")
|
|
||||||
for file in saved_files:
|
|
||||||
size = file.stat().st_size
|
|
||||||
print(f" - {file.name}: {size:,} bytes")
|
|
||||||
|
|
||||||
# Create or get repository
|
|
||||||
print(f"\nCreating/accessing repository: {repo_id}")
|
|
||||||
api = HfApi(token=token)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Create repo if it doesn't exist
|
|
||||||
create_repo(
|
|
||||||
repo_id,
|
|
||||||
private=private,
|
|
||||||
token=token,
|
|
||||||
exist_ok=True,
|
|
||||||
)
|
|
||||||
print(f" ✓ Repository ready: https://huggingface.co/{repo_id}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f" ⚠️ Note: {e}")
|
|
||||||
|
|
||||||
# Upload to hub
|
|
||||||
print("\nUploading to HuggingFace hub...")
|
|
||||||
api.upload_folder(
|
|
||||||
folder_path=str(save_path),
|
|
||||||
repo_id=repo_id,
|
|
||||||
repo_type="model",
|
|
||||||
token=token,
|
|
||||||
commit_message="Upload PI0OpenPI model with proper LeRobot config format",
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"\n✓ Model successfully uploaded to: https://huggingface.co/{repo_id}")
|
|
||||||
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print("✓ Process complete!")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
return policy
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import argparse
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="Push PI0OpenPI model to HuggingFace hub")
|
|
||||||
parser.add_argument(
|
|
||||||
"--repo-id",
|
|
||||||
type=str,
|
|
||||||
default="test-user/pi0-openpi-test",
|
|
||||||
help="HuggingFace repository ID (e.g., 'username/model-name')",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--private",
|
|
||||||
action="store_true",
|
|
||||||
help="Create a private repository",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--token",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help="HuggingFace API token (optional, uses cached token if not provided)",
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# Run the upload
|
|
||||||
create_and_push_model(
|
|
||||||
repo_id=args.repo_id,
|
|
||||||
private=args.private,
|
|
||||||
token=args.token,
|
|
||||||
)
|
|
||||||
@@ -7,8 +7,10 @@ import torch
|
|||||||
from lerobot.policies.pi0_openpi.configuration_pi0openpi import PI0OpenPIConfig
|
from lerobot.policies.pi0_openpi.configuration_pi0openpi import PI0OpenPIConfig
|
||||||
from lerobot.policies.pi0_openpi.modeling_pi0openpi import PI0OpenPIPolicy
|
from lerobot.policies.pi0_openpi.modeling_pi0openpi import PI0OpenPIPolicy
|
||||||
from lerobot.policies.pi05_openpi import PI05OpenPIConfig, PI05OpenPIPolicy
|
from lerobot.policies.pi05_openpi import PI05OpenPIConfig, PI05OpenPIPolicy
|
||||||
|
from tests.utils import require_nightly_gpu
|
||||||
|
|
||||||
|
|
||||||
|
@require_nightly_gpu
|
||||||
def test_pi05_model_architecture():
|
def test_pi05_model_architecture():
|
||||||
"""Test that pi05=True creates the correct model architecture."""
|
"""Test that pi05=True creates the correct model architecture."""
|
||||||
print("Testing PI0.5 model architecture...")
|
print("Testing PI0.5 model architecture...")
|
||||||
@@ -75,9 +77,8 @@ def test_pi05_model_architecture():
|
|||||||
)
|
)
|
||||||
print("✓ AdaRMS correctly configured: PaliGemma=False, Expert=True")
|
print("✓ AdaRMS correctly configured: PaliGemma=False, Expert=True")
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
|
@require_nightly_gpu
|
||||||
def test_pi05_forward_pass():
|
def test_pi05_forward_pass():
|
||||||
"""Test forward pass with"""
|
"""Test forward pass with"""
|
||||||
print("\nTesting PI0.5 forward pass...")
|
print("\nTesting PI0.5 forward pass...")
|
||||||
@@ -126,7 +127,7 @@ def test_pi05_forward_pass():
|
|||||||
assert loss.item() >= 0, "Loss should be non-negative"
|
assert loss.item() >= 0, "Loss should be non-negative"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"✗ Forward pass failed: {e}")
|
print(f"✗ Forward pass failed: {e}")
|
||||||
return False
|
raise
|
||||||
|
|
||||||
# Test action prediction
|
# Test action prediction
|
||||||
try:
|
try:
|
||||||
@@ -138,11 +139,10 @@ def test_pi05_forward_pass():
|
|||||||
assert not torch.isnan(action).any(), "Action contains NaN values"
|
assert not torch.isnan(action).any(), "Action contains NaN values"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"✗ Action prediction failed: {e}")
|
print(f"✗ Action prediction failed: {e}")
|
||||||
return False
|
raise
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
|
@require_nightly_gpu
|
||||||
def test_pi0_vs_pi05_differences():
|
def test_pi0_vs_pi05_differences():
|
||||||
"""Test key differences between pi0 and pi05 modes."""
|
"""Test key differences between pi0 and pi05 modes."""
|
||||||
print("\nComparing PI0 vs PI0.5 architectures...")
|
print("\nComparing PI0 vs PI0.5 architectures...")
|
||||||
@@ -183,44 +183,3 @@ def test_pi0_vs_pi05_differences():
|
|||||||
print(f" - PI0: {pi0_params:,}")
|
print(f" - PI0: {pi0_params:,}")
|
||||||
print(f" - PI0.5: {pi05_params:,}")
|
print(f" - PI0.5: {pi05_params:,}")
|
||||||
print(f" - Difference: {pi0_params - pi05_params:,} (PI0.5 has fewer params due to no state embedding)")
|
print(f" - Difference: {pi0_params - pi05_params:,} (PI0.5 has fewer params due to no state embedding)")
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
"""Run all PI0.5 tests."""
|
|
||||||
print("=" * 60)
|
|
||||||
print("PI0.5 Support Test Suite")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
tests = [
|
|
||||||
("Model Architecture", test_pi05_model_architecture),
|
|
||||||
("Forward Pass", test_pi05_forward_pass),
|
|
||||||
("PI0 vs PI0.5 Comparison", test_pi0_vs_pi05_differences),
|
|
||||||
]
|
|
||||||
|
|
||||||
all_passed = True
|
|
||||||
for test_name, test_func in tests:
|
|
||||||
print(f"\n[{test_name}]")
|
|
||||||
print("-" * 40)
|
|
||||||
try:
|
|
||||||
if not test_func():
|
|
||||||
all_passed = False
|
|
||||||
print(f"✗ {test_name} failed")
|
|
||||||
except Exception as e:
|
|
||||||
all_passed = False
|
|
||||||
print(f"✗ {test_name} failed with exception: {e}")
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
traceback.print_exc()
|
|
||||||
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
if all_passed:
|
|
||||||
print("✅ All PI0.5 tests passed!")
|
|
||||||
else:
|
|
||||||
print("❌ Some tests failed.")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -6,8 +6,10 @@ import torch
|
|||||||
|
|
||||||
from lerobot.policies.factory import make_policy_config
|
from lerobot.policies.factory import make_policy_config
|
||||||
from lerobot.policies.pi0_openpi import PI0OpenPIConfig, PI0OpenPIPolicy
|
from lerobot.policies.pi0_openpi import PI0OpenPIConfig, PI0OpenPIPolicy
|
||||||
|
from tests.utils import require_nightly_gpu
|
||||||
|
|
||||||
|
|
||||||
|
@require_nightly_gpu
|
||||||
def test_policy_instantiation():
|
def test_policy_instantiation():
|
||||||
"""Test basic policy instantiation."""
|
"""Test basic policy instantiation."""
|
||||||
print("Testing PI0OpenPI policy instantiation...")
|
print("Testing PI0OpenPI policy instantiation...")
|
||||||
@@ -63,6 +65,7 @@ def test_policy_instantiation():
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
@require_nightly_gpu
|
||||||
def test_config_creation():
|
def test_config_creation():
|
||||||
"""Test policy config creation through factory."""
|
"""Test policy config creation through factory."""
|
||||||
print("\nTesting config creation through factory...")
|
print("\nTesting config creation through factory...")
|
||||||
@@ -81,29 +84,3 @@ def test_config_creation():
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"✗ Config creation failed: {e}")
|
print(f"✗ Config creation failed: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
"""Run all tests."""
|
|
||||||
print("=" * 60)
|
|
||||||
print("PI0OpenPI Policy Integration Test")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
# Test config creation
|
|
||||||
config_test = test_config_creation()
|
|
||||||
|
|
||||||
print("\n" + "-" * 60)
|
|
||||||
|
|
||||||
# Test policy instantiation
|
|
||||||
policy_test = test_policy_instantiation()
|
|
||||||
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
if config_test and policy_test:
|
|
||||||
print("✓ All tests passed!")
|
|
||||||
else:
|
|
||||||
print("✗ Some tests failed.")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -10,6 +10,7 @@ from openpi.models_pytorch.pi0_pytorch import PI0Pytorch
|
|||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
from lerobot.policies.pi0_openpi import PI0OpenPIConfig, PI0OpenPIPolicy
|
from lerobot.policies.pi0_openpi import PI0OpenPIConfig, PI0OpenPIPolicy
|
||||||
|
from tests.utils import require_nightly_gpu
|
||||||
|
|
||||||
DUMMY_ACTION_DIM = 32
|
DUMMY_ACTION_DIM = 32
|
||||||
DUMMY_STATE_DIM = 32
|
DUMMY_STATE_DIM = 32
|
||||||
@@ -311,7 +312,9 @@ def create_original_observation_from_lerobot(lerobot_pi0, batch):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
@require_nightly_gpu
|
||||||
|
def test_pi0_original_vs_lerobot():
|
||||||
|
"""Test PI0 original implementation vs LeRobot implementation."""
|
||||||
print("Initializing models...")
|
print("Initializing models...")
|
||||||
lerobot_pi0 = instantiate_lerobot_pi0(from_pretrained=True) # Load pretrained LeRobot model
|
lerobot_pi0 = instantiate_lerobot_pi0(from_pretrained=True) # Load pretrained LeRobot model
|
||||||
original_pi0 = instantiate_original_pi0(
|
original_pi0 = instantiate_original_pi0(
|
||||||
@@ -376,21 +379,18 @@ def main():
|
|||||||
print(f"OpenPI (LeRobot preprocessing) Actions std: {openpi_actions_lerobot_preproc.std().item():.6f}")
|
print(f"OpenPI (LeRobot preprocessing) Actions std: {openpi_actions_lerobot_preproc.std().item():.6f}")
|
||||||
|
|
||||||
print("\nComparing models with same preprocessing:")
|
print("\nComparing models with same preprocessing:")
|
||||||
print(
|
is_close_1e4 = torch.allclose(lerobot_actions_own, openpi_actions_lerobot_preproc, atol=1e-4)
|
||||||
f"Actions close (atol=1e-4): {torch.allclose(lerobot_actions_own, openpi_actions_lerobot_preproc, atol=1e-4)}"
|
is_close_1e2 = torch.allclose(lerobot_actions_own, openpi_actions_lerobot_preproc, atol=1e-2)
|
||||||
)
|
max_diff = torch.abs(lerobot_actions_own - openpi_actions_lerobot_preproc).max().item()
|
||||||
print(
|
|
||||||
f"Actions close (atol=1e-2): {torch.allclose(lerobot_actions_own, openpi_actions_lerobot_preproc, atol=1e-2)}"
|
print(f"Actions close (atol=1e-4): {is_close_1e4}")
|
||||||
)
|
print(f"Actions close (atol=1e-2): {is_close_1e2}")
|
||||||
print(
|
print(f"Max absolute difference: {max_diff:.6f}")
|
||||||
f"Max absolute difference: {torch.abs(lerobot_actions_own - openpi_actions_lerobot_preproc).max().item():.6f}"
|
|
||||||
)
|
# Add assertions for pytest
|
||||||
|
assert is_close_1e2, f"Models should produce similar results (atol=1e-2), max diff: {max_diff}"
|
||||||
|
|
||||||
print("\n=== SUMMARY ===")
|
print("\n=== SUMMARY ===")
|
||||||
print("Test 1 compares end-to-end pipelines (each model with its own preprocessing)")
|
print("Test 1 compares end-to-end pipelines (each model with its own preprocessing)")
|
||||||
print("Test 2 isolates model differences (both models with LeRobot preprocessing)")
|
print("Test 2 isolates model differences (both models with LeRobot preprocessing)")
|
||||||
print("Both tests completed successfully!")
|
print("Both tests completed successfully!")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -6,6 +6,7 @@ import torch
|
|||||||
|
|
||||||
from lerobot.policies.pi0_openpi import PI0OpenPIPolicy
|
from lerobot.policies.pi0_openpi import PI0OpenPIPolicy
|
||||||
from lerobot.policies.pi05_openpi.modeling_pi05openpi import PI05OpenPIPolicy
|
from lerobot.policies.pi05_openpi.modeling_pi05openpi import PI05OpenPIPolicy
|
||||||
|
from tests.utils import require_nightly_gpu
|
||||||
|
|
||||||
|
|
||||||
def create_dummy_stats(config):
|
def create_dummy_stats(config):
|
||||||
@@ -31,8 +32,20 @@ def create_dummy_stats(config):
|
|||||||
return dummy_stats
|
return dummy_stats
|
||||||
|
|
||||||
|
|
||||||
def test_hub_loading(model_id="pepijn223/pi0_base_fp32", model_name="PI0"):
|
@require_nightly_gpu
|
||||||
"""Test loading model from HuggingFace hub.
|
def test_pi0_hub_loading():
|
||||||
|
"""Test loading PI0 model from HuggingFace hub."""
|
||||||
|
_test_hub_loading(model_id="pepijn223/pi0_base_fp32", model_name="PI0")
|
||||||
|
|
||||||
|
|
||||||
|
@require_nightly_gpu
|
||||||
|
def test_pi05_hub_loading():
|
||||||
|
"""Test loading PI0.5 model from HuggingFace hub."""
|
||||||
|
_test_hub_loading(model_id="pepijn223/pi05_base_fp32", model_name="PI0.5")
|
||||||
|
|
||||||
|
|
||||||
|
def _test_hub_loading(model_id, model_name):
|
||||||
|
"""Internal helper function for testing hub loading.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_id: HuggingFace model ID to load
|
model_id: HuggingFace model ID to load
|
||||||
@@ -119,7 +132,7 @@ def test_hub_loading(model_id="pepijn223/pi0_base_fp32", model_name="PI0"):
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"✗ Failed to load model: {e}")
|
print(f"✗ Failed to load model: {e}")
|
||||||
return False
|
raise
|
||||||
|
|
||||||
print("\n" + "-" * 60)
|
print("\n" + "-" * 60)
|
||||||
print("Testing forward pass with loaded model...")
|
print("Testing forward pass with loaded model...")
|
||||||
@@ -197,7 +210,7 @@ def test_hub_loading(model_id="pepijn223/pi0_base_fp32", model_name="PI0"):
|
|||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return False
|
raise
|
||||||
|
|
||||||
print("\n" + "-" * 60)
|
print("\n" + "-" * 60)
|
||||||
print("Testing inference with loaded model...")
|
print("Testing inference with loaded model...")
|
||||||
@@ -216,58 +229,8 @@ def test_hub_loading(model_id="pepijn223/pi0_base_fp32", model_name="PI0"):
|
|||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return False
|
raise
|
||||||
|
|
||||||
print("\n" + "=" * 60)
|
print("\n" + "=" * 60)
|
||||||
print(f"✓ All tests passed for {model_name}!")
|
print(f"✓ All tests passed for {model_name}!")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
"""Run tests for both PI0 and PI0.5 models."""
|
|
||||||
print("\n")
|
|
||||||
print("╔" + "═" * 58 + "╗")
|
|
||||||
print("║" + " PI0 & PI0.5 HuggingFace Hub Loading Test Suite ".center(58) + "║")
|
|
||||||
print("╚" + "═" * 58 + "╝")
|
|
||||||
print()
|
|
||||||
|
|
||||||
results = []
|
|
||||||
|
|
||||||
# Test PI0 model
|
|
||||||
print("\n[Test 1/2] Testing PI0 model...")
|
|
||||||
print("─" * 60)
|
|
||||||
pi0_success = test_hub_loading(model_id="pepijn223/pi0_base_fp32", model_name="PI0")
|
|
||||||
results.append(("PI0", pi0_success))
|
|
||||||
|
|
||||||
# Test PI0.5 model
|
|
||||||
print("\n\n[Test 2/2] Testing PI0.5 model...")
|
|
||||||
print("─" * 60)
|
|
||||||
pi05_success = test_hub_loading(model_id="pepijn223/pi05_base_fp32", model_name="PI0.5")
|
|
||||||
results.append(("PI0.5", pi05_success))
|
|
||||||
|
|
||||||
# Summary
|
|
||||||
print("\n\n")
|
|
||||||
print("╔" + "═" * 58 + "╗")
|
|
||||||
print("║" + " TEST SUMMARY ".center(58) + "║")
|
|
||||||
print("╚" + "═" * 58 + "╝")
|
|
||||||
|
|
||||||
all_passed = True
|
|
||||||
for model_name, success in results:
|
|
||||||
status = "✅ PASSED" if success else "❌ FAILED"
|
|
||||||
print(f" {model_name:10} : {status}")
|
|
||||||
if not success:
|
|
||||||
all_passed = False
|
|
||||||
|
|
||||||
print()
|
|
||||||
if all_passed:
|
|
||||||
print("🎉 All models loaded and tested successfully!")
|
|
||||||
else:
|
|
||||||
print("⚠️ Some tests failed. Check the output above for details.")
|
|
||||||
|
|
||||||
return all_passed
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
success = main()
|
|
||||||
exit(0 if success else 1)
|
|
||||||
@@ -167,6 +167,24 @@ def require_package_arg(func):
|
|||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def require_nightly_gpu(func):
|
||||||
|
"""
|
||||||
|
Decorator that skips the test unless running in nightly environment with GPU.
|
||||||
|
Combines GPU availability check with nightly workflow detection.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@require_cuda
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
# Check if running in nightly workflow (GitHub Actions)
|
||||||
|
is_nightly = os.environ.get("GITHUB_WORKFLOW") == "Nightly"
|
||||||
|
if not is_nightly:
|
||||||
|
pytest.skip("Test only runs in nightly workflow with GPU")
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def require_package(package_name):
|
def require_package(package_name):
|
||||||
"""
|
"""
|
||||||
Decorator that skips the test if the specified package is not installed.
|
Decorator that skips the test if the specified package is not installed.
|
||||||
|
|||||||
Reference in New Issue
Block a user