Ignore padded GR00T N1.7 RTC prefix rows

This commit is contained in:
Andrew Wrenn
2026-06-03 14:04:31 -07:00
parent 1d6810b814
commit 6caeac9d07
2 changed files with 44 additions and 0 deletions
@@ -373,6 +373,17 @@ class GrootPolicy(PreTrainedPolicy):
"prev_chunk_left_over batch size must match the current GR00T N1.7 batch size."
)
# The generic LeRobot RTC engine pads short leftovers with exact zero
# rows for fixed-shape policy calls. Native GR00T N1.7 RTC treats every
# provided prefix row as a real action constraint, so strip that padding
# before constructing the native overlap options.
valid_prefix_rows = prev_actions.detach().abs().sum(dim=(0, 2)) > 0
if valid_prefix_rows.any():
valid_prefix_steps = int(valid_prefix_rows.nonzero()[-1].item()) + 1
prev_actions = prev_actions[:, :valid_prefix_steps, :]
else:
return inputs, None
model_action_horizon = int(getattr(self._groot_model.config, "action_horizon", self.config.chunk_size))
max_action_dim = int(getattr(self._groot_model.config, "max_action_dim", self.config.max_action_dim))
if prev_actions.shape[1] > model_action_horizon:
+33
View File
@@ -458,6 +458,39 @@ def test_groot_predict_action_chunk_forwards_n1_7_rtc_prefix(monkeypatch):
torch.testing.assert_close(dummy_model.forward_inputs["action"][0, :, 7:], torch.zeros(8, 125))
def test_groot_predict_action_chunk_strips_padded_n1_7_rtc_prefix(monkeypatch):
from lerobot.policies.groot.groot_n1_7 import GR00TN17
dummy_model = _DummyGrootModel()
monkeypatch.setattr(GR00TN17, "from_pretrained", classmethod(lambda cls, **kwargs: dummy_model))
config = _groot_config(GROOT_N1_7)
policy = GrootPolicy(config)
policy.config.rtc_config = SimpleNamespace(execution_horizon=6)
prev_chunk = torch.cat(
(
torch.arange(4 * 7, dtype=torch.float32).view(4, 7) + 1.0,
torch.zeros(2, 7),
)
)
policy.predict_action_chunk(
{"state": torch.zeros(1, 1, 132)},
inference_delay=5,
prev_chunk_left_over=prev_chunk,
)
assert dummy_model.get_action_options == {
"action_horizon": 4,
"rtc_overlap_steps": 4,
"rtc_frozen_steps": 4,
"rtc_ramp_rate": 6.0,
}
assert dummy_model.forward_inputs["action"].shape == (1, 4, 132)
torch.testing.assert_close(dummy_model.forward_inputs["action"][0, :, :7], prev_chunk[:4])
torch.testing.assert_close(dummy_model.forward_inputs["action"][0, :, 7:], torch.zeros(4, 125))
def test_groot_n1_7_predict_action_chunk_truncates_to_checkpoint_valid_horizon(tmp_path, monkeypatch):
from lerobot.policies.groot.groot_n1_7 import GR00TN17