From 3159f473df9f6c95295a0159eb221123a1639f65 Mon Sep 17 00:00:00 2001 From: Andrew Wrenn Date: Wed, 3 Jun 2026 13:51:35 -0700 Subject: [PATCH] Trim GR00T N1.7 RTC chunks to valid horizon --- src/lerobot/policies/groot/modeling_groot.py | 24 ++++++++++++++ tests/policies/groot/test_groot_n1_7.py | 34 ++++++++++++++++++++ 2 files changed, 58 insertions(+) diff --git a/src/lerobot/policies/groot/modeling_groot.py b/src/lerobot/policies/groot/modeling_groot.py index 9e6c0ac7e..1c03043c2 100644 --- a/src/lerobot/policies/groot/modeling_groot.py +++ b/src/lerobot/policies/groot/modeling_groot.py @@ -292,6 +292,27 @@ class GrootPolicy(PreTrainedPolicy): horizons.append(execution_horizon) return min(horizons) + def _resolve_prediction_horizon(self, actions: Tensor) -> int: + """Return the policy-facing action horizon for a native GR00T prediction.""" + + if self.config.model_version != GROOT_N1_7: + return actions.shape[1] + + horizons = [actions.shape[1]] + checkpoint_action_horizon = infer_groot_n1_7_action_horizon( + self.config.base_model_path, + self.config.embodiment_tag, + ) + if checkpoint_action_horizon is not None: + horizons.append(checkpoint_action_horizon) + + for horizon in (self.config.chunk_size, self.config.n_action_steps): + horizon = int(horizon) + if horizon > 0: + horizons.append(horizon) + + return max(1, min(horizons)) + def _filter_groot_inputs(self, batch: dict[str, Tensor], *, include_action: bool) -> dict[str, Tensor]: allowed_base = {"state", "state_mask", "embodiment_id"} if include_action: @@ -455,6 +476,9 @@ class GrootPolicy(PreTrainedPolicy): actions = outputs.get("action_pred") + prediction_horizon = self._resolve_prediction_horizon(actions) + actions = actions[:, :prediction_horizon] + original_action_dim = self.config.output_features[ACTION].shape[0] actions = actions[:, :, :original_action_dim] diff --git a/tests/policies/groot/test_groot_n1_7.py b/tests/policies/groot/test_groot_n1_7.py index a804950d8..122079819 100644 --- a/tests/policies/groot/test_groot_n1_7.py +++ b/tests/policies/groot/test_groot_n1_7.py @@ -458,6 +458,40 @@ def test_groot_predict_action_chunk_forwards_n1_7_rtc_prefix(monkeypatch): torch.testing.assert_close(dummy_model.forward_inputs["action"][0, :, 7:], torch.zeros(8, 125)) +def test_groot_n1_7_predict_action_chunk_truncates_to_checkpoint_valid_horizon(tmp_path, monkeypatch): + from lerobot.policies.groot.groot_n1_7 import GR00TN17 + + model_path = tmp_path / "libero_spatial" + _write_raw_n1_7_libero_checkpoint(model_path) + + class HorizonModel(_DummyGrootModel): + def get_action(self, inputs, options=None): + del options + batch_size = inputs["state"].shape[0] + steps = torch.arange(40, dtype=torch.float32).view(1, 40, 1).expand(batch_size, 40, 132) + return {"action_pred": steps} + + monkeypatch.setattr(GR00TN17, "from_pretrained", classmethod(lambda cls, **kwargs: HorizonModel())) + input_features, output_features = _groot_features(state_dim=8, action_dim=7) + config = GrootConfig( + model_version=GROOT_N1_7, + base_model_path=str(model_path), + embodiment_tag="libero_sim", + input_features=input_features, + output_features=output_features, + device="cpu", + use_bf16=False, + chunk_size=40, + n_action_steps=40, + ) + policy = GrootPolicy(config) + + actions = policy.predict_action_chunk({"state": torch.zeros(1, 1, 132)}) + + assert actions.shape == (1, 16, 7) + torch.testing.assert_close(actions[0, :, 0], torch.arange(16, dtype=torch.float32)) + + def test_groot_from_pretrained_rejects_mismatched_caller_config(tmp_path): model_path = tmp_path / "GR00T-N1.7-local" model_path.mkdir()