mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 16:57:12 +00:00
Ignore padded GR00T N1.7 RTC prefix rows
This commit is contained in:
@@ -373,6 +373,17 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
"prev_chunk_left_over batch size must match the current GR00T N1.7 batch size."
|
||||
)
|
||||
|
||||
# The generic LeRobot RTC engine pads short leftovers with exact zero
|
||||
# rows for fixed-shape policy calls. Native GR00T N1.7 RTC treats every
|
||||
# provided prefix row as a real action constraint, so strip that padding
|
||||
# before constructing the native overlap options.
|
||||
valid_prefix_rows = prev_actions.detach().abs().sum(dim=(0, 2)) > 0
|
||||
if valid_prefix_rows.any():
|
||||
valid_prefix_steps = int(valid_prefix_rows.nonzero()[-1].item()) + 1
|
||||
prev_actions = prev_actions[:, :valid_prefix_steps, :]
|
||||
else:
|
||||
return inputs, None
|
||||
|
||||
model_action_horizon = int(getattr(self._groot_model.config, "action_horizon", self.config.chunk_size))
|
||||
max_action_dim = int(getattr(self._groot_model.config, "max_action_dim", self.config.max_action_dim))
|
||||
if prev_actions.shape[1] > model_action_horizon:
|
||||
|
||||
@@ -458,6 +458,39 @@ def test_groot_predict_action_chunk_forwards_n1_7_rtc_prefix(monkeypatch):
|
||||
torch.testing.assert_close(dummy_model.forward_inputs["action"][0, :, 7:], torch.zeros(8, 125))
|
||||
|
||||
|
||||
def test_groot_predict_action_chunk_strips_padded_n1_7_rtc_prefix(monkeypatch):
|
||||
from lerobot.policies.groot.groot_n1_7 import GR00TN17
|
||||
|
||||
dummy_model = _DummyGrootModel()
|
||||
monkeypatch.setattr(GR00TN17, "from_pretrained", classmethod(lambda cls, **kwargs: dummy_model))
|
||||
config = _groot_config(GROOT_N1_7)
|
||||
policy = GrootPolicy(config)
|
||||
policy.config.rtc_config = SimpleNamespace(execution_horizon=6)
|
||||
|
||||
prev_chunk = torch.cat(
|
||||
(
|
||||
torch.arange(4 * 7, dtype=torch.float32).view(4, 7) + 1.0,
|
||||
torch.zeros(2, 7),
|
||||
)
|
||||
)
|
||||
|
||||
policy.predict_action_chunk(
|
||||
{"state": torch.zeros(1, 1, 132)},
|
||||
inference_delay=5,
|
||||
prev_chunk_left_over=prev_chunk,
|
||||
)
|
||||
|
||||
assert dummy_model.get_action_options == {
|
||||
"action_horizon": 4,
|
||||
"rtc_overlap_steps": 4,
|
||||
"rtc_frozen_steps": 4,
|
||||
"rtc_ramp_rate": 6.0,
|
||||
}
|
||||
assert dummy_model.forward_inputs["action"].shape == (1, 4, 132)
|
||||
torch.testing.assert_close(dummy_model.forward_inputs["action"][0, :, :7], prev_chunk[:4])
|
||||
torch.testing.assert_close(dummy_model.forward_inputs["action"][0, :, 7:], torch.zeros(4, 125))
|
||||
|
||||
|
||||
def test_groot_n1_7_predict_action_chunk_truncates_to_checkpoint_valid_horizon(tmp_path, monkeypatch):
|
||||
from lerobot.policies.groot.groot_n1_7 import GR00TN17
|
||||
|
||||
|
||||
Reference in New Issue
Block a user