mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-05 09:07:03 +00:00
Trim GR00T N1.7 RTC chunks to valid horizon
This commit is contained in:
@@ -292,6 +292,27 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
horizons.append(execution_horizon)
|
||||
return min(horizons)
|
||||
|
||||
def _resolve_prediction_horizon(self, actions: Tensor) -> int:
|
||||
"""Return the policy-facing action horizon for a native GR00T prediction."""
|
||||
|
||||
if self.config.model_version != GROOT_N1_7:
|
||||
return actions.shape[1]
|
||||
|
||||
horizons = [actions.shape[1]]
|
||||
checkpoint_action_horizon = infer_groot_n1_7_action_horizon(
|
||||
self.config.base_model_path,
|
||||
self.config.embodiment_tag,
|
||||
)
|
||||
if checkpoint_action_horizon is not None:
|
||||
horizons.append(checkpoint_action_horizon)
|
||||
|
||||
for horizon in (self.config.chunk_size, self.config.n_action_steps):
|
||||
horizon = int(horizon)
|
||||
if horizon > 0:
|
||||
horizons.append(horizon)
|
||||
|
||||
return max(1, min(horizons))
|
||||
|
||||
def _filter_groot_inputs(self, batch: dict[str, Tensor], *, include_action: bool) -> dict[str, Tensor]:
|
||||
allowed_base = {"state", "state_mask", "embodiment_id"}
|
||||
if include_action:
|
||||
@@ -455,6 +476,9 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
|
||||
actions = outputs.get("action_pred")
|
||||
|
||||
prediction_horizon = self._resolve_prediction_horizon(actions)
|
||||
actions = actions[:, :prediction_horizon]
|
||||
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
actions = actions[:, :, :original_action_dim]
|
||||
|
||||
|
||||
@@ -458,6 +458,40 @@ def test_groot_predict_action_chunk_forwards_n1_7_rtc_prefix(monkeypatch):
|
||||
torch.testing.assert_close(dummy_model.forward_inputs["action"][0, :, 7:], torch.zeros(8, 125))
|
||||
|
||||
|
||||
def test_groot_n1_7_predict_action_chunk_truncates_to_checkpoint_valid_horizon(tmp_path, monkeypatch):
|
||||
from lerobot.policies.groot.groot_n1_7 import GR00TN17
|
||||
|
||||
model_path = tmp_path / "libero_spatial"
|
||||
_write_raw_n1_7_libero_checkpoint(model_path)
|
||||
|
||||
class HorizonModel(_DummyGrootModel):
|
||||
def get_action(self, inputs, options=None):
|
||||
del options
|
||||
batch_size = inputs["state"].shape[0]
|
||||
steps = torch.arange(40, dtype=torch.float32).view(1, 40, 1).expand(batch_size, 40, 132)
|
||||
return {"action_pred": steps}
|
||||
|
||||
monkeypatch.setattr(GR00TN17, "from_pretrained", classmethod(lambda cls, **kwargs: HorizonModel()))
|
||||
input_features, output_features = _groot_features(state_dim=8, action_dim=7)
|
||||
config = GrootConfig(
|
||||
model_version=GROOT_N1_7,
|
||||
base_model_path=str(model_path),
|
||||
embodiment_tag="libero_sim",
|
||||
input_features=input_features,
|
||||
output_features=output_features,
|
||||
device="cpu",
|
||||
use_bf16=False,
|
||||
chunk_size=40,
|
||||
n_action_steps=40,
|
||||
)
|
||||
policy = GrootPolicy(config)
|
||||
|
||||
actions = policy.predict_action_chunk({"state": torch.zeros(1, 1, 132)})
|
||||
|
||||
assert actions.shape == (1, 16, 7)
|
||||
torch.testing.assert_close(actions[0, :, 0], torch.arange(16, dtype=torch.float32))
|
||||
|
||||
|
||||
def test_groot_from_pretrained_rejects_mismatched_caller_config(tmp_path):
|
||||
model_path = tmp_path / "GR00T-N1.7-local"
|
||||
model_path.mkdir()
|
||||
|
||||
Reference in New Issue
Block a user