Trim GR00T N1.7 RTC chunks to valid horizon

This commit is contained in:
Andrew Wrenn
2026-06-03 13:51:35 -07:00
committed by Andy Wrenn
parent bed3747804
commit 3159f473df
2 changed files with 58 additions and 0 deletions
@@ -292,6 +292,27 @@ class GrootPolicy(PreTrainedPolicy):
horizons.append(execution_horizon)
return min(horizons)
def _resolve_prediction_horizon(self, actions: Tensor) -> int:
"""Return the policy-facing action horizon for a native GR00T prediction."""
if self.config.model_version != GROOT_N1_7:
return actions.shape[1]
horizons = [actions.shape[1]]
checkpoint_action_horizon = infer_groot_n1_7_action_horizon(
self.config.base_model_path,
self.config.embodiment_tag,
)
if checkpoint_action_horizon is not None:
horizons.append(checkpoint_action_horizon)
for horizon in (self.config.chunk_size, self.config.n_action_steps):
horizon = int(horizon)
if horizon > 0:
horizons.append(horizon)
return max(1, min(horizons))
def _filter_groot_inputs(self, batch: dict[str, Tensor], *, include_action: bool) -> dict[str, Tensor]:
allowed_base = {"state", "state_mask", "embodiment_id"}
if include_action:
@@ -455,6 +476,9 @@ class GrootPolicy(PreTrainedPolicy):
actions = outputs.get("action_pred")
prediction_horizon = self._resolve_prediction_horizon(actions)
actions = actions[:, :prediction_horizon]
original_action_dim = self.config.output_features[ACTION].shape[0]
actions = actions[:, :, :original_action_dim]
+34
View File
@@ -458,6 +458,40 @@ def test_groot_predict_action_chunk_forwards_n1_7_rtc_prefix(monkeypatch):
torch.testing.assert_close(dummy_model.forward_inputs["action"][0, :, 7:], torch.zeros(8, 125))
def test_groot_n1_7_predict_action_chunk_truncates_to_checkpoint_valid_horizon(tmp_path, monkeypatch):
from lerobot.policies.groot.groot_n1_7 import GR00TN17
model_path = tmp_path / "libero_spatial"
_write_raw_n1_7_libero_checkpoint(model_path)
class HorizonModel(_DummyGrootModel):
def get_action(self, inputs, options=None):
del options
batch_size = inputs["state"].shape[0]
steps = torch.arange(40, dtype=torch.float32).view(1, 40, 1).expand(batch_size, 40, 132)
return {"action_pred": steps}
monkeypatch.setattr(GR00TN17, "from_pretrained", classmethod(lambda cls, **kwargs: HorizonModel()))
input_features, output_features = _groot_features(state_dim=8, action_dim=7)
config = GrootConfig(
model_version=GROOT_N1_7,
base_model_path=str(model_path),
embodiment_tag="libero_sim",
input_features=input_features,
output_features=output_features,
device="cpu",
use_bf16=False,
chunk_size=40,
n_action_steps=40,
)
policy = GrootPolicy(config)
actions = policy.predict_action_chunk({"state": torch.zeros(1, 1, 132)})
assert actions.shape == (1, 16, 7)
torch.testing.assert_close(actions[0, :, 0], torch.arange(16, dtype=torch.float32))
def test_groot_from_pretrained_rejects_mismatched_caller_config(tmp_path):
model_path = tmp_path / "GR00T-N1.7-local"
model_path.mkdir()