cleanup tests

This commit is contained in:
Pepijn
2025-09-17 17:35:07 +02:00
parent bc10fc7696
commit 64974c38c2
4 changed files with 37 additions and 97 deletions
+6 -14
View File
@@ -15,9 +15,6 @@ from tests.utils import require_cuda
@require_cuda
def test_policy_instantiation():
"""Test basic policy instantiation."""
print("Testing PI0OpenPI policy instantiation...")
# Create config
config = PI0OpenPIConfig(max_action_dim=7, max_state_dim=14, dtype="float32")
@@ -35,7 +32,6 @@ def test_policy_instantiation():
# Instantiate policy
policy = PI0OpenPIPolicy(config, dataset_stats)
print(f"Policy created successfully: {policy.name}")
# Test forward pass with dummy data
batch_size = 1
@@ -49,39 +45,35 @@ def test_policy_instantiation():
"task": ["Pick up the object"] * batch_size,
}
print("\nTesting forward pass...")
try:
loss, loss_dict = policy.forward(batch)
print(f"Forward pass successful. Loss: {loss_dict['loss']:.4f}")
print(f"Forward pass successful. Loss: {loss_dict['loss']:.4f}")
except Exception as e:
print(f"Forward pass failed: {e}")
print(f"Forward pass failed: {e}")
raise
print("\nTesting action prediction...")
try:
with torch.no_grad():
action = policy.select_action(batch)
print(f"Action prediction successful. Action shape: {action.shape}")
print(f"Action prediction successful. Action shape: {action.shape}")
except Exception as e:
print(f"Action prediction failed: {e}")
print(f"Action prediction failed: {e}")
raise
@require_cuda
def test_config_creation():
"""Test policy config creation through factory."""
print("\nTesting config creation through factory...")
try:
config = make_policy_config(
policy_type="pi0_openpi",
max_action_dim=7,
max_state_dim=14,
)
print("Config created successfully through factory")
print("Config created successfully through factory")
print(f" Config type: {type(config).__name__}")
print(f" PaliGemma variant: {config.paligemma_variant}")
print(f" Action expert variant: {config.action_expert_variant}")
except Exception as e:
print(f"Config creation failed: {e}")
print(f"Config creation failed: {e}")
raise