From 21e63b505f473baa55b642a89770b3fae42549e1 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Wed, 10 Sep 2025 21:41:05 +0200 Subject: [PATCH] fix test --- test_pi05_openpi.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test_pi05_openpi.py b/test_pi05_openpi.py index 4393d088c..aeb081a24 100644 --- a/test_pi05_openpi.py +++ b/test_pi05_openpi.py @@ -135,7 +135,8 @@ def test_pi05_forward_pass(): with torch.no_grad(): action = policy.select_action(batch) print(f"✓ Action prediction successful. Action shape: {action.shape}") - assert action.shape == (7,), f"Expected action shape (7,), got {action.shape}" + # When batch_size > 1, select_action returns (batch_size, action_dim) + assert action.shape == (batch_size, 7), f"Expected action shape ({batch_size}, 7), got {action.shape}" assert not torch.isnan(action).any(), "Action contains NaN values" except Exception as e: print(f"✗ Action prediction failed: {e}")