diff --git a/test_pi0_openpi.py b/test_pi0_openpi.py index d7e5ffc0b..ad8419f7a 100644 --- a/test_pi0_openpi.py +++ b/test_pi0_openpi.py @@ -13,7 +13,7 @@ def test_policy_instantiation(): print("Testing PI0OpenPI policy instantiation...") # Create config - config = PI0OpenPIConfig(action_dim=7, state_dim=14, device="cpu", dtype="float32") + config = PI0OpenPIConfig(action_dim=7, state_dim=14, dtype="float32") # Create dummy dataset stats dataset_stats = {