This commit is contained in:
Pepijn
2025-09-10 21:41:05 +02:00
parent e9e7eb827a
commit 21e63b505f
+2 -1
View File
@@ -135,7 +135,8 @@ def test_pi05_forward_pass():
with torch.no_grad():
action = policy.select_action(batch)
print(f"✓ Action prediction successful. Action shape: {action.shape}")
assert action.shape == (7,), f"Expected action shape (7,), got {action.shape}"
# When batch_size > 1, select_action returns (batch_size, action_dim)
assert action.shape == (batch_size, 7), f"Expected action shape ({batch_size}, 7), got {action.shape}"
assert not torch.isnan(action).any(), "Action contains NaN values"
except Exception as e:
print(f"✗ Action prediction failed: {e}")