From 364750ada2273b97a76c4bbfced5b4f263180a43 Mon Sep 17 00:00:00 2001 From: Andrew Wrenn Date: Tue, 2 Jun 2026 14:20:00 -0700 Subject: [PATCH] Allow Groot fake RTC chunk prefetch --- src/lerobot/policies/groot/modeling_groot.py | 6 +++++- tests/policies/groot/test_groot_n1_7.py | 8 ++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/src/lerobot/policies/groot/modeling_groot.py b/src/lerobot/policies/groot/modeling_groot.py index a28d0c148..c6068c875 100644 --- a/src/lerobot/policies/groot/modeling_groot.py +++ b/src/lerobot/policies/groot/modeling_groot.py @@ -342,10 +342,14 @@ class GrootPolicy(PreTrainedPolicy): return loss, loss_dict @torch.no_grad() - def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: + def predict_action_chunk(self, batch: dict[str, Tensor], **_: object) -> Tensor: """Predict a chunk of actions for inference by delegating to Isaac-GR00T. Returns a tensor of shape (B, n_action_steps, action_dim). + + Groot does not currently implement LeRobot's RTC guidance contract. Accept + and ignore action-selection kwargs so the RTC engine can still use Groot as + an async chunk producer. """ self.eval() diff --git a/tests/policies/groot/test_groot_n1_7.py b/tests/policies/groot/test_groot_n1_7.py index 41b6ab8e1..9a6061f12 100644 --- a/tests/policies/groot/test_groot_n1_7.py +++ b/tests/policies/groot/test_groot_n1_7.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect import json import sys from types import SimpleNamespace @@ -419,6 +420,13 @@ def test_groot_n1_7_can_be_selected_from_policy_config_factory_without_external_ assert "gr00t" not in sys.modules +def test_groot_predict_action_chunk_accepts_rtc_kwargs(): + signature = inspect.signature(GrootPolicy.predict_action_chunk) + + assert any(parameter.kind is inspect.Parameter.VAR_KEYWORD for parameter in signature.parameters.values()) + signature.bind(object(), {}, inference_delay=2, prev_chunk_left_over=None) + + def test_groot_from_pretrained_rejects_mismatched_caller_config(tmp_path): model_path = tmp_path / "GR00T-N1.7-local" model_path.mkdir()