diff --git a/src/lerobot/policies/groot/modeling_groot.py b/src/lerobot/policies/groot/modeling_groot.py index 1c03043c2..fe3cc4059 100644 --- a/src/lerobot/policies/groot/modeling_groot.py +++ b/src/lerobot/policies/groot/modeling_groot.py @@ -373,6 +373,17 @@ class GrootPolicy(PreTrainedPolicy): "prev_chunk_left_over batch size must match the current GR00T N1.7 batch size." ) + # The generic LeRobot RTC engine pads short leftovers with exact zero + # rows for fixed-shape policy calls. Native GR00T N1.7 RTC treats every + # provided prefix row as a real action constraint, so strip that padding + # before constructing the native overlap options. + valid_prefix_rows = prev_actions.detach().abs().sum(dim=(0, 2)) > 0 + if valid_prefix_rows.any(): + valid_prefix_steps = int(valid_prefix_rows.nonzero()[-1].item()) + 1 + prev_actions = prev_actions[:, :valid_prefix_steps, :] + else: + return inputs, None + model_action_horizon = int(getattr(self._groot_model.config, "action_horizon", self.config.chunk_size)) max_action_dim = int(getattr(self._groot_model.config, "max_action_dim", self.config.max_action_dim)) if prev_actions.shape[1] > model_action_horizon: diff --git a/tests/policies/groot/test_groot_n1_7.py b/tests/policies/groot/test_groot_n1_7.py index 122079819..a4fd365ca 100644 --- a/tests/policies/groot/test_groot_n1_7.py +++ b/tests/policies/groot/test_groot_n1_7.py @@ -458,6 +458,39 @@ def test_groot_predict_action_chunk_forwards_n1_7_rtc_prefix(monkeypatch): torch.testing.assert_close(dummy_model.forward_inputs["action"][0, :, 7:], torch.zeros(8, 125)) +def test_groot_predict_action_chunk_strips_padded_n1_7_rtc_prefix(monkeypatch): + from lerobot.policies.groot.groot_n1_7 import GR00TN17 + + dummy_model = _DummyGrootModel() + monkeypatch.setattr(GR00TN17, "from_pretrained", classmethod(lambda cls, **kwargs: dummy_model)) + config = _groot_config(GROOT_N1_7) + policy = GrootPolicy(config) + policy.config.rtc_config = SimpleNamespace(execution_horizon=6) + + prev_chunk = torch.cat( + ( + torch.arange(4 * 7, dtype=torch.float32).view(4, 7) + 1.0, + torch.zeros(2, 7), + ) + ) + + policy.predict_action_chunk( + {"state": torch.zeros(1, 1, 132)}, + inference_delay=5, + prev_chunk_left_over=prev_chunk, + ) + + assert dummy_model.get_action_options == { + "action_horizon": 4, + "rtc_overlap_steps": 4, + "rtc_frozen_steps": 4, + "rtc_ramp_rate": 6.0, + } + assert dummy_model.forward_inputs["action"].shape == (1, 4, 132) + torch.testing.assert_close(dummy_model.forward_inputs["action"][0, :, :7], prev_chunk[:4]) + torch.testing.assert_close(dummy_model.forward_inputs["action"][0, :, 7:], torch.zeros(4, 125)) + + def test_groot_n1_7_predict_action_chunk_truncates_to_checkpoint_valid_horizon(tmp_path, monkeypatch): from lerobot.policies.groot.groot_n1_7 import GR00TN17