Add test to instatiate all base models

This commit is contained in:
Pepijn
2025-09-16 13:31:29 +02:00
parent 6aaeb7c13f
commit 0e0d6fbfc2
2 changed files with 146 additions and 6 deletions
+3 -6
View File
@@ -51,7 +51,7 @@ def test_policy_instantiation():
print(f"✓ Forward pass successful. Loss: {loss_dict['loss']:.4f}")
except Exception as e:
print(f"✗ Forward pass failed: {e}")
return False
raise
print("\nTesting action prediction...")
try:
@@ -60,9 +60,7 @@ def test_policy_instantiation():
print(f"✓ Action prediction successful. Action shape: {action.shape}")
except Exception as e:
print(f"✗ Action prediction failed: {e}")
return False
return True
raise
@require_nightly_gpu
@@ -80,7 +78,6 @@ def test_config_creation():
print(f" Config type: {type(config).__name__}")
print(f" PaliGemma variant: {config.paligemma_variant}")
print(f" Action expert variant: {config.action_expert_variant}")
return True
except Exception as e:
print(f"✗ Config creation failed: {e}")
return False
raise