fix bug for inference

This commit is contained in:
Geoffrey19
2025-12-16 13:21:04 +08:00
committed by Michel Aractingi
parent fc6262e23d
commit 51bd288f1a
2 changed files with 20 additions and 1 deletions
+9
View File
@@ -99,6 +99,15 @@ def test_policy_instantiation():
print(f"Forward pass failed: {e}")
raise
# Test inference
batch = {
"observation.state": torch.randn(batch_size, 7, dtype=torch.float32, device=device),
"observation.images.face_view": torch.rand(
batch_size, 3, 224, 224, dtype=torch.float32, device=device
), # Use rand for [0,1] range
"task": ["Pick up the object"] * batch_size,
}
batch = preprocessor(batch)
try:
with torch.no_grad():
action = policy.select_action(batch)