Fix GR00T N1.7 RTC action decoding

This commit is contained in:
Andrew Wrenn
2026-06-03 13:43:13 -07:00
parent 364750ada2
commit de9af57475
3 changed files with 259 additions and 12 deletions
+93 -5
View File
@@ -319,6 +319,85 @@ class GrootPolicy(PreTrainedPolicy):
if (k in allowed_base or k.startswith("eagle_")) and not (k.startswith("next.") or k == "info")
}
def _prepare_n1_7_rtc_inputs(
    self,
    inputs: dict[str, Tensor],
    *,
    inference_delay: object,
    prev_chunk_left_over: object,
) -> tuple[dict[str, Tensor], dict[str, object] | None]:
    """Turn LeRobot RTC leftovers into GR00T N1.7 action-overlap inputs.

    Args:
        inputs: Filtered GR00T batch; must contain ``state`` when RTC applies.
        inference_delay: Number of leading steps to freeze. Values that
            cannot be converted with ``int()`` are treated as 0.
        prev_chunk_left_over: ``None``, or a tensor of shape ``(T, A)`` or
            ``(B, T, A)`` holding the leftover actions of the previous chunk.

    Returns:
        ``(inputs, None)`` unchanged when RTC does not apply (model is not
        N1.7, no/empty leftover, or zero overlap). Otherwise a shallow copy
        of ``inputs`` with ``"action"`` set to the leftover — batched,
        cropped/zero-padded to the model's horizon and action width, moved to
        the device/dtype of ``state`` — plus the options dict to pass to the
        model's ``get_action``.

    Raises:
        TypeError: ``prev_chunk_left_over`` is neither ``None`` nor a tensor.
        ValueError: the leftover has an unsupported rank, ``state`` is
            missing, or the leftover batch size does not match ``state``.
    """
    if self.config.model_version != GROOT_N1_7 or prev_chunk_left_over is None:
        return inputs, None
    if not isinstance(prev_chunk_left_over, torch.Tensor):
        raise TypeError("prev_chunk_left_over must be a torch.Tensor for GR00T N1.7 RTC.")
    if prev_chunk_left_over.numel() == 0:
        return inputs, None

    # Normalise the leftover to (B, T, A).
    prev_actions = prev_chunk_left_over
    if prev_actions.ndim == 2:
        prev_actions = prev_actions.unsqueeze(0)
    elif prev_actions.ndim != 3:
        raise ValueError(
            "prev_chunk_left_over must have shape (T, A) or (B, T, A) for GR00T N1.7 RTC."
        )

    state = inputs.get("state")
    if state is None:
        raise ValueError("GR00T N1.7 RTC requires `state` in the preprocessed batch.")
    batch_size = state.shape[0]
    if prev_actions.shape[0] == 1 and batch_size > 1:
        # Materialise the broadcast so later ops see an ordinary dense tensor.
        prev_actions = prev_actions.expand(batch_size, -1, -1).clone()
    elif prev_actions.shape[0] != batch_size:
        raise ValueError(
            "prev_chunk_left_over batch size must match the current GR00T N1.7 batch size."
        )

    model_action_horizon = int(getattr(self._groot_model.config, "action_horizon", self.config.chunk_size))
    max_action_dim = int(getattr(self._groot_model.config, "max_action_dim", self.config.max_action_dim))

    if prev_actions.shape[1] > model_action_horizon:
        # Keep the LEADING steps. The overlap/frozen counts computed below are
        # measured from index 0 of this tensor, so cropping from the front
        # would shift the prefix in time.
        # NOTE(review): assumes leftover index 0 is the next action to execute
        # (LeRobot RTC contract) — confirm against the RTC action queue.
        prev_actions = prev_actions[:, :model_action_horizon, :]
    action_horizon = int(prev_actions.shape[1])
    if action_horizon <= 0:
        return inputs, None

    # Match the model's action width: crop extra features or zero-pad missing ones.
    if prev_actions.shape[2] > max_action_dim:
        prev_actions = prev_actions[:, :, :max_action_dim]
    elif prev_actions.shape[2] < max_action_dim:
        pad = torch.zeros(
            prev_actions.shape[0],
            prev_actions.shape[1],
            max_action_dim - prev_actions.shape[2],
            dtype=prev_actions.dtype,
            device=prev_actions.device,
        )
        prev_actions = torch.cat([prev_actions, pad], dim=2)
    prev_actions = prev_actions.to(device=state.device, dtype=state.dtype)

    # Overlap is bounded by both the configured execution horizon and the
    # number of leftover steps actually available.
    rtc_config = getattr(self.config, "rtc_config", None)
    execution_horizon = int(getattr(rtc_config, "execution_horizon", action_horizon))
    overlap_steps = max(0, min(action_horizon, execution_horizon))
    if overlap_steps == 0:
        return inputs, None

    # Best effort: a missing or non-numeric delay means "freeze nothing".
    try:
        frozen_steps = int(inference_delay or 0)
    except (TypeError, ValueError):
        frozen_steps = 0
    frozen_steps = max(0, min(frozen_steps, overlap_steps))

    options = {
        "action_horizon": action_horizon,
        "rtc_overlap_steps": overlap_steps,
        "rtc_frozen_steps": frozen_steps,
        "rtc_ramp_rate": float(getattr(self._groot_model.config, "rtc_ramp_rate", 6.0)),
    }
    # Shallow copy so the caller's batch dict is not mutated.
    inputs = dict(inputs)
    inputs["action"] = prev_actions
    return inputs, options
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""Training forward pass.
@@ -342,14 +421,13 @@ class GrootPolicy(PreTrainedPolicy):
return loss, loss_dict
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor], **_: object) -> Tensor:
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: object) -> Tensor:
"""Predict a chunk of actions for inference by delegating to Isaac-GR00T.
Returns a tensor of shape (B, n_action_steps, action_dim).
Groot does not currently implement LeRobot's RTC guidance contract. Accept
and ignore action-selection kwargs so the RTC engine can still use Groot as
an async chunk producer.
For N1.7, LeRobot's RTC leftovers are converted into the native GR00T
action-overlap options before calling the underlying model.
"""
self.eval()
@@ -357,13 +435,23 @@ class GrootPolicy(PreTrainedPolicy):
# During inference, we do not pass action because it is predicted.
# N1.7 still carries a 2-D action horizon mask from its checkpoint processor.
groot_inputs = self._filter_groot_inputs(batch, include_action=False)
groot_options = None
if self.config.model_version == GROOT_N1_7:
groot_inputs, groot_options = self._prepare_n1_7_rtc_inputs(
groot_inputs,
inference_delay=kwargs.get("inference_delay"),
prev_chunk_left_over=kwargs.get("prev_chunk_left_over"),
)
# Get device from model parameters
device = next(self.parameters()).device
# Use bf16 autocast for inference to keep memory low and match backbone dtype
with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=self.config.use_bf16):
outputs = self._groot_model.get_action(groot_inputs)
if groot_options is not None:
outputs = self._groot_model.get_action(groot_inputs, options=groot_options)
else:
outputs = self._groot_model.get_action(groot_inputs)
actions = outputs.get("action_pred")
+43 -6
View File
@@ -494,6 +494,17 @@ def _legacy_groot_processor_overrides(
return preprocessor_overrides, postprocessor_overrides
def _local_processor_config_has_step(pretrained_path: str, config_filename: str, step_name: str) -> bool:
path = Path(pretrained_path).expanduser()
if not path.is_dir():
return False
config = _read_json(path / config_filename)
steps = config.get("steps", [])
if not isinstance(steps, list):
return False
return any(isinstance(step, dict) and step.get("registry_name") == step_name for step in steps)
def make_groot_pre_post_processors_from_pretrained(
config: GrootConfig,
pretrained_path: str,
@@ -517,12 +528,26 @@ def make_groot_pre_post_processors_from_pretrained(
dataset_stats=dataset_stats,
)
preprocessor_overrides, postprocessor_overrides = _legacy_groot_processor_overrides(
config=config,
dataset_stats=dataset_stats,
preprocessor_overrides=preprocessor_overrides,
postprocessor_overrides=postprocessor_overrides,
)
if (
config.model_version == GROOT_N1_7
and _local_processor_config_has_step(
pretrained_path,
postprocessor_config_filename,
"groot_n1_7_action_decode_v1",
)
):
# Converted raw N1.7 checkpoints already carry the checkpoint-specific
# action decoder. Adding the legacy action-unpack override would target
# a step that is not present and break loading.
preprocessor_overrides = dict(preprocessor_overrides or {})
postprocessor_overrides = dict(postprocessor_overrides or {})
else:
preprocessor_overrides, postprocessor_overrides = _legacy_groot_processor_overrides(
config=config,
dataset_stats=dataset_stats,
preprocessor_overrides=preprocessor_overrides,
postprocessor_overrides=postprocessor_overrides,
)
preprocessor = PolicyProcessorPipeline.from_pretrained(
pretrained_model_name_or_path=pretrained_path,
config_filename=preprocessor_config_filename,
@@ -1712,6 +1737,15 @@ def _unnormalize_min_max(action: np.ndarray, min_v: np.ndarray, max_v: np.ndarra
return (np.clip(action, -1.0, 1.0) + 1.0) * 0.5 * (max_v - min_v) + min_v
def _n1_7_decode_valid_horizon(action_config: dict[str, Any], action_np: np.ndarray) -> int | None:
if action_np.ndim != 3:
return None
delta_indices = action_config.get("delta_indices", [])
if not isinstance(delta_indices, list) or not delta_indices:
return None
return max(1, min(action_np.shape[1], len(delta_indices)))
def _rot6d_to_matrix(rot6d: np.ndarray) -> np.ndarray:
rows = rot6d.reshape(2, 3).astype(np.float64)
row1 = rows[0] / np.linalg.norm(rows[0])
@@ -1824,6 +1858,9 @@ class GrootN17ActionDecodeStep(ProcessorStep):
return transition
action_np = action.detach().cpu().float().numpy()
valid_horizon = _n1_7_decode_valid_horizon(action_config, action_np)
if valid_horizon is not None:
action_np = action_np[:, :valid_horizon]
decoded_groups: dict[str, np.ndarray] = {}
start_idx = 0
for idx, key in enumerate(action_keys):
+123 -1
View File
@@ -334,13 +334,15 @@ class _DummyGrootModel(nn.Module):
self.config = SimpleNamespace(compute_dtype="float32")
self.compute_dtype = "float32"
self.forward_inputs = None
self.get_action_options = None
def forward(self, inputs):
self.forward_inputs = dict(inputs)
return {"loss": self.weight + 1.0}
def get_action(self, inputs):
def get_action(self, inputs, options=None):
self.forward_inputs = dict(inputs)
self.get_action_options = options
batch_size = inputs["state"].shape[0]
return {"action_pred": torch.zeros(batch_size, 40, 132, device=self.weight.device)}
@@ -427,6 +429,35 @@ def test_groot_predict_action_chunk_accepts_rtc_kwargs():
signature.bind(object(), {}, inference_delay=2, prev_chunk_left_over=None)
def test_groot_predict_action_chunk_forwards_n1_7_rtc_prefix(monkeypatch):
    """N1.7 RTC leftovers reach ``get_action`` as a padded ``action`` plus options."""
    from lerobot.policies.groot.groot_n1_7 import GR00TN17

    stub_model = _DummyGrootModel()
    monkeypatch.setattr(GR00TN17, "from_pretrained", classmethod(lambda cls, **kwargs: stub_model))

    policy = GrootPolicy(_groot_config(GROOT_N1_7))
    policy.config.rtc_config = SimpleNamespace(execution_horizon=6)

    # An unbatched (T=8, A=7) leftover with distinct values in every cell.
    leftover = torch.arange(8 * 7, dtype=torch.float32).view(8, 7)
    batch = {"state": torch.zeros(1, 1, 132)}

    predicted = policy.predict_action_chunk(
        batch,
        inference_delay=3,
        prev_chunk_left_over=leftover,
    )

    assert predicted.shape == (1, 40, 7)

    expected_options = {
        "action_horizon": 8,
        "rtc_overlap_steps": 6,
        "rtc_frozen_steps": 3,
        "rtc_ramp_rate": 6.0,
    }
    assert stub_model.get_action_options == expected_options

    # The leftover is batched and zero-padded out to the 132-wide action vector.
    forwarded = stub_model.forward_inputs["action"]
    assert forwarded.shape == (1, 8, 132)
    torch.testing.assert_close(forwarded[0, :, :7], leftover)
    torch.testing.assert_close(forwarded[0, :, 7:], torch.zeros(8, 125))
def test_groot_from_pretrained_rejects_mismatched_caller_config(tmp_path):
model_path = tmp_path / "GR00T-N1.7-local"
model_path.mkdir()
@@ -593,6 +624,27 @@ def test_groot_n1_7_saved_processors_round_trip_checkpoint_specific_fields(tmp_p
assert decode_actions.raw_stats["action"]["gripper"]["q99"] == [115.0]
def test_converted_raw_n1_7_processors_load_without_legacy_action_unpack_override(tmp_path):
    """Processors saved from a raw N1.7 checkpoint reload with their N1.7 steps intact."""
    raw_checkpoint_dir = tmp_path / "libero_spatial"
    _write_raw_n1_7_libero_checkpoint(raw_checkpoint_dir)
    config = _raw_n1_7_libero_config(raw_checkpoint_dir)
    preprocessor, postprocessor = make_pre_post_processors(config, pretrained_path=str(raw_checkpoint_dir))

    # Persist config and both pipelines, as a converted checkpoint would be laid out.
    converted_dir = tmp_path / "converted_pretrained_model"
    for artifact in (config, preprocessor, postprocessor):
        artifact.save_pretrained(converted_dir)

    reloaded_pre, reloaded_post = make_pre_post_processors(
        config,
        pretrained_path=str(converted_dir),
        preprocessor_overrides={"rename_observations_processor": {"rename_map": {}}},
    )

    has_pack_step = any(isinstance(step, GrootN17PackInputsStep) for step in reloaded_pre.steps)
    has_decode_step = any(isinstance(step, GrootN17ActionDecodeStep) for step in reloaded_post.steps)
    assert has_pack_step
    assert has_decode_step
def test_groot_n1_7_pack_inputs_rejects_state_dim_above_core_max():
step = GrootN17PackInputsStep(
max_state_dim=2,
@@ -941,6 +993,76 @@ def test_groot_n1_7_action_decode_applies_named_libero_transform_from_modality_k
torch.testing.assert_close(output[TransitionKey.ACTION], expected)
def test_groot_n1_7_action_decode_truncates_to_valid_horizon_for_relative_stats():
    """A 40-step prediction is decoded only over the 16 steps listed in ``delta_indices``."""
    horizon = 16
    arm_dim = 5

    def per_step(value):
        # One row of per-joint statistics for each step of the valid horizon.
        return [[value] * arm_dim for _ in range(horizon)]

    arm_low = per_step(-1.0)
    arm_high = per_step(1.0)

    gripper_action_stats = {
        "min": [0.0],
        "max": [10.0],
        "mean": [5.0],
        "std": [1.0],
        "q01": [0.0],
        "q99": [10.0],
    }
    relative_arm_stats = {
        "min": arm_low,
        "max": arm_high,
        "mean": per_step(0.0),
        "std": per_step(1.0),
        "q01": arm_low,
        "q99": arm_high,
    }
    raw_stats = {
        "state": {
            "single_arm": _stats([0.0] * arm_dim),
            "gripper": _stats([0.0]),
        },
        "action": {
            "single_arm": _stats([0.0] * arm_dim),
            "gripper": gripper_action_stats,
        },
        "relative_action": {
            "single_arm": relative_arm_stats,
        },
    }

    modality_config = {
        "state": {
            "modality_keys": ["single_arm", "gripper"],
        },
        "action": {
            "delta_indices": list(range(horizon)),
            "modality_keys": ["single_arm", "gripper"],
            "action_configs": [
                {"rep": "RELATIVE", "type": "NON_EEF", "format": "DEFAULT", "state_key": None},
                {"rep": "ABSOLUTE", "type": "NON_EEF", "format": "DEFAULT", "state_key": None},
            ],
        },
    }

    # Run the pack step once on an all-zero state before decoding.
    pack_step = GrootN17PackInputsStep(
        raw_stats=raw_stats,
        modality_config=modality_config,
        normalize_min_max=False,
    )
    zero_state_transition = {
        TransitionKey.OBSERVATION: {OBS_STATE: torch.zeros(1, 6)},
        TransitionKey.COMPLEMENTARY_DATA: {},
    }
    pack_step(zero_state_transition)

    decode_step = GrootN17ActionDecodeStep(
        env_action_dim=6,
        raw_stats=raw_stats,
        modality_config=modality_config,
        use_relative_action=True,
        pack_step=pack_step,
    )

    result = decode_step({TransitionKey.ACTION: torch.zeros(1, 40, 6)})
    decoded = result[TransitionKey.ACTION]

    assert decoded.shape == (1, horizon, 6)
    torch.testing.assert_close(decoded[..., :5], torch.zeros(1, horizon, 5))
    torch.testing.assert_close(decoded[..., 5], torch.full((1, horizon), 5.0))
def test_groot_n1_7_action_decode_requires_gripper_key_for_libero_transform():
step = GrootN17ActionDecodeStep(
env_action_dim=1,