mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 03:59:42 +00:00
fix test
This commit is contained in:
+2
-1
@@ -135,7 +135,8 @@ def test_pi05_forward_pass():
|
||||
with torch.no_grad():
|
||||
action = policy.select_action(batch)
|
||||
print(f"✓ Action prediction successful. Action shape: {action.shape}")
|
||||
assert action.shape == (7,), f"Expected action shape (7,), got {action.shape}"
|
||||
# When batch_size > 1, select_action returns (batch_size, action_dim)
|
||||
assert action.shape == (batch_size, 7), f"Expected action shape ({batch_size}, 7), got {action.shape}"
|
||||
assert not torch.isnan(action).any(), "Action contains NaN values"
|
||||
except Exception as e:
|
||||
print(f"✗ Action prediction failed: {e}")
|
||||
|
||||
Reference in New Issue
Block a user