Allow Groot fake RTC chunk prefetch

This commit is contained in:
Andrew Wrenn
2026-06-02 14:20:00 -07:00
parent 342d223706
commit 364750ada2
2 changed files with 13 additions and 1 deletions
+5 -1
View File
@@ -342,10 +342,14 @@ class GrootPolicy(PreTrainedPolicy):
return loss, loss_dict
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
def predict_action_chunk(self, batch: dict[str, Tensor], **_: object) -> Tensor:
"""Predict a chunk of actions for inference by delegating to Isaac-GR00T.
Returns a tensor of shape (B, n_action_steps, action_dim).
Groot does not currently implement LeRobot's RTC guidance contract. Accept
and ignore action-selection kwargs so the RTC engine can still use Groot as
an async chunk producer.
"""
self.eval()
+8
View File
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import json
import sys
from types import SimpleNamespace
@@ -419,6 +420,13 @@ def test_groot_n1_7_can_be_selected_from_policy_config_factory_without_external_
assert "gr00t" not in sys.modules
def test_groot_predict_action_chunk_accepts_rtc_kwargs():
signature = inspect.signature(GrootPolicy.predict_action_chunk)
assert any(parameter.kind is inspect.Parameter.VAR_KEYWORD for parameter in signature.parameters.values())
signature.bind(object(), {}, inference_delay=2, prev_chunk_left_over=None)
def test_groot_from_pretrained_rejects_mismatched_caller_config(tmp_path):
model_path = tmp_path / "GR00T-N1.7-local"
model_path.mkdir()