mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
fix test
This commit is contained in:
+2
-1
@@ -135,7 +135,8 @@ def test_pi05_forward_pass():
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
action = policy.select_action(batch)
|
action = policy.select_action(batch)
|
||||||
print(f"✓ Action prediction successful. Action shape: {action.shape}")
|
print(f"✓ Action prediction successful. Action shape: {action.shape}")
|
||||||
assert action.shape == (7,), f"Expected action shape (7,), got {action.shape}"
|
# When batch_size > 1, select_action returns (batch_size, action_dim)
|
||||||
|
assert action.shape == (batch_size, 7), f"Expected action shape ({batch_size}, 7), got {action.shape}"
|
||||||
assert not torch.isnan(action).any(), "Action contains NaN values"
|
assert not torch.isnan(action).any(), "Action contains NaN values"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"✗ Action prediction failed: {e}")
|
print(f"✗ Action prediction failed: {e}")
|
||||||
|
|||||||
Reference in New Issue
Block a user