From de9af574758248709ecd016690d14d7926c6db35 Mon Sep 17 00:00:00 2001 From: Andrew Wrenn Date: Wed, 3 Jun 2026 13:43:13 -0700 Subject: [PATCH] Fix GR00T N1.7 RTC action decoding --- src/lerobot/policies/groot/modeling_groot.py | 98 +++++++++++++- src/lerobot/policies/groot/processor_groot.py | 49 ++++++- tests/policies/groot/test_groot_n1_7.py | 124 +++++++++++++++++- 3 files changed, 259 insertions(+), 12 deletions(-) diff --git a/src/lerobot/policies/groot/modeling_groot.py b/src/lerobot/policies/groot/modeling_groot.py index c6068c875..9e6c0ac7e 100644 --- a/src/lerobot/policies/groot/modeling_groot.py +++ b/src/lerobot/policies/groot/modeling_groot.py @@ -319,6 +319,85 @@ class GrootPolicy(PreTrainedPolicy): if (k in allowed_base or k.startswith("eagle_")) and not (k.startswith("next.") or k == "info") } + def _prepare_n1_7_rtc_inputs( + self, + inputs: dict[str, Tensor], + *, + inference_delay: object, + prev_chunk_left_over: object, + ) -> tuple[dict[str, Tensor], dict[str, object] | None]: + if self.config.model_version != GROOT_N1_7 or prev_chunk_left_over is None: + return inputs, None + if not isinstance(prev_chunk_left_over, torch.Tensor): + raise TypeError("prev_chunk_left_over must be a torch.Tensor for GR00T N1.7 RTC.") + if prev_chunk_left_over.numel() == 0: + return inputs, None + + prev_actions = prev_chunk_left_over + if prev_actions.ndim == 2: + prev_actions = prev_actions.unsqueeze(0) + elif prev_actions.ndim != 3: + raise ValueError( + "prev_chunk_left_over must have shape (T, A) or (B, T, A) for GR00T N1.7 RTC." + ) + + state = inputs.get("state") + if state is None: + raise ValueError("GR00T N1.7 RTC requires `state` in the preprocessed batch.") + batch_size = state.shape[0] + if prev_actions.shape[0] == 1 and batch_size > 1: + prev_actions = prev_actions.expand(batch_size, -1, -1).clone() + elif prev_actions.shape[0] != batch_size: + raise ValueError( + "prev_chunk_left_over batch size must match the current GR00T N1.7 batch size." + ) + + model_action_horizon = int(getattr(self._groot_model.config, "action_horizon", self.config.chunk_size)) + max_action_dim = int(getattr(self._groot_model.config, "max_action_dim", self.config.max_action_dim)) + if prev_actions.shape[1] > model_action_horizon: + prev_actions = prev_actions[:, -model_action_horizon:, :] + + action_horizon = int(prev_actions.shape[1]) + if action_horizon <= 0: + return inputs, None + + if prev_actions.shape[2] > max_action_dim: + prev_actions = prev_actions[:, :, :max_action_dim] + elif prev_actions.shape[2] < max_action_dim: + pad = torch.zeros( + prev_actions.shape[0], + prev_actions.shape[1], + max_action_dim - prev_actions.shape[2], + dtype=prev_actions.dtype, + device=prev_actions.device, + ) + prev_actions = torch.cat([prev_actions, pad], dim=2) + + prev_actions = prev_actions.to(device=state.device, dtype=state.dtype) + + rtc_config = getattr(self.config, "rtc_config", None) + execution_horizon = int(getattr(rtc_config, "execution_horizon", action_horizon)) + overlap_steps = max(0, min(action_horizon, execution_horizon)) + if overlap_steps == 0: + return inputs, None + + try: + frozen_steps = int(inference_delay or 0) + except (TypeError, ValueError): + frozen_steps = 0 + frozen_steps = max(0, min(frozen_steps, overlap_steps)) + + options = { + "action_horizon": action_horizon, + "rtc_overlap_steps": overlap_steps, + "rtc_frozen_steps": frozen_steps, + "rtc_ramp_rate": float(getattr(self._groot_model.config, "rtc_ramp_rate", 6.0)), + } + + inputs = dict(inputs) + inputs["action"] = prev_actions + return inputs, options + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: """Training forward pass. @@ -342,14 +421,13 @@ class GrootPolicy(PreTrainedPolicy): return loss, loss_dict @torch.no_grad() - def predict_action_chunk(self, batch: dict[str, Tensor], **_: object) -> Tensor: + def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: object) -> Tensor: """Predict a chunk of actions for inference by delegating to Isaac-GR00T. Returns a tensor of shape (B, n_action_steps, action_dim). - Groot does not currently implement LeRobot's RTC guidance contract. Accept - and ignore action-selection kwargs so the RTC engine can still use Groot as - an async chunk producer. + For N1.7, LeRobot's RTC leftovers are converted into the native GR00T + action-overlap options before calling the underlying model. """ self.eval() @@ -357,13 +435,23 @@ class GrootPolicy(PreTrainedPolicy): # During inference, we do not pass action because it is predicted. # N1.7 still carries a 2-D action horizon mask from its checkpoint processor. groot_inputs = self._filter_groot_inputs(batch, include_action=False) + groot_options = None + if self.config.model_version == GROOT_N1_7: + groot_inputs, groot_options = self._prepare_n1_7_rtc_inputs( + groot_inputs, + inference_delay=kwargs.get("inference_delay"), + prev_chunk_left_over=kwargs.get("prev_chunk_left_over"), + ) # Get device from model parameters device = next(self.parameters()).device # Use bf16 autocast for inference to keep memory low and match backbone dtype with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=self.config.use_bf16): - outputs = self._groot_model.get_action(groot_inputs) + if groot_options is not None: + outputs = self._groot_model.get_action(groot_inputs, options=groot_options) + else: + outputs = self._groot_model.get_action(groot_inputs) actions = outputs.get("action_pred") diff --git a/src/lerobot/policies/groot/processor_groot.py b/src/lerobot/policies/groot/processor_groot.py index 03e0f7877..594948584 100644 --- a/src/lerobot/policies/groot/processor_groot.py +++ b/src/lerobot/policies/groot/processor_groot.py @@ -494,6 +494,17 @@ def _legacy_groot_processor_overrides( return preprocessor_overrides, postprocessor_overrides +def _local_processor_config_has_step(pretrained_path: str, config_filename: str, step_name: str) -> bool: + path = Path(pretrained_path).expanduser() + if not path.is_dir(): + return False + config = _read_json(path / config_filename) + steps = config.get("steps", []) + if not isinstance(steps, list): + return False + return any(isinstance(step, dict) and step.get("registry_name") == step_name for step in steps) + + def make_groot_pre_post_processors_from_pretrained( config: GrootConfig, pretrained_path: str, @@ -517,12 +528,26 @@ def make_groot_pre_post_processors_from_pretrained( dataset_stats=dataset_stats, ) - preprocessor_overrides, postprocessor_overrides = _legacy_groot_processor_overrides( - config=config, - dataset_stats=dataset_stats, - preprocessor_overrides=preprocessor_overrides, - postprocessor_overrides=postprocessor_overrides, - ) + if ( + config.model_version == GROOT_N1_7 + and _local_processor_config_has_step( + pretrained_path, + postprocessor_config_filename, + "groot_n1_7_action_decode_v1", + ) + ): + # Converted raw N1.7 checkpoints already carry the checkpoint-specific + # action decoder. Adding the legacy action-unpack override would target + # a step that is not present and break loading. + preprocessor_overrides = dict(preprocessor_overrides or {}) + postprocessor_overrides = dict(postprocessor_overrides or {}) + else: + preprocessor_overrides, postprocessor_overrides = _legacy_groot_processor_overrides( + config=config, + dataset_stats=dataset_stats, + preprocessor_overrides=preprocessor_overrides, + postprocessor_overrides=postprocessor_overrides, + ) preprocessor = PolicyProcessorPipeline.from_pretrained( pretrained_model_name_or_path=pretrained_path, config_filename=preprocessor_config_filename, @@ -1712,6 +1737,15 @@ def _unnormalize_min_max(action: np.ndarray, min_v: np.ndarray, max_v: np.ndarra return (np.clip(action, -1.0, 1.0) + 1.0) * 0.5 * (max_v - min_v) + min_v +def _n1_7_decode_valid_horizon(action_config: dict[str, Any], action_np: np.ndarray) -> int | None: + if action_np.ndim != 3: + return None + delta_indices = action_config.get("delta_indices", []) + if not isinstance(delta_indices, list) or not delta_indices: + return None + return max(1, min(action_np.shape[1], len(delta_indices))) + + def _rot6d_to_matrix(rot6d: np.ndarray) -> np.ndarray: rows = rot6d.reshape(2, 3).astype(np.float64) row1 = rows[0] / np.linalg.norm(rows[0]) @@ -1824,6 +1858,9 @@ class GrootN17ActionDecodeStep(ProcessorStep): return transition action_np = action.detach().cpu().float().numpy() + valid_horizon = _n1_7_decode_valid_horizon(action_config, action_np) + if valid_horizon is not None: + action_np = action_np[:, :valid_horizon] decoded_groups: dict[str, np.ndarray] = {} start_idx = 0 for idx, key in enumerate(action_keys): diff --git a/tests/policies/groot/test_groot_n1_7.py b/tests/policies/groot/test_groot_n1_7.py index 9a6061f12..a804950d8 100644 --- a/tests/policies/groot/test_groot_n1_7.py +++ b/tests/policies/groot/test_groot_n1_7.py @@ -334,13 +334,15 @@ class _DummyGrootModel(nn.Module): self.config = SimpleNamespace(compute_dtype="float32") self.compute_dtype = "float32" self.forward_inputs = None + self.get_action_options = None def forward(self, inputs): self.forward_inputs = dict(inputs) return {"loss": self.weight + 1.0} - def get_action(self, inputs): + def get_action(self, inputs, options=None): self.forward_inputs = dict(inputs) + self.get_action_options = options batch_size = inputs["state"].shape[0] return {"action_pred": torch.zeros(batch_size, 40, 132, device=self.weight.device)} @@ -427,6 +429,35 @@ def test_groot_predict_action_chunk_accepts_rtc_kwargs(): signature.bind(object(), {}, inference_delay=2, prev_chunk_left_over=None) +def test_groot_predict_action_chunk_forwards_n1_7_rtc_prefix(monkeypatch): + from lerobot.policies.groot.groot_n1_7 import GR00TN17 + + dummy_model = _DummyGrootModel() + monkeypatch.setattr(GR00TN17, "from_pretrained", classmethod(lambda cls, **kwargs: dummy_model)) + config = _groot_config(GROOT_N1_7) + policy = GrootPolicy(config) + policy.config.rtc_config = SimpleNamespace(execution_horizon=6) + + prev_chunk = torch.arange(8 * 7, dtype=torch.float32).view(8, 7) + + actions = policy.predict_action_chunk( + {"state": torch.zeros(1, 1, 132)}, + inference_delay=3, + prev_chunk_left_over=prev_chunk, + ) + + assert actions.shape == (1, 40, 7) + assert dummy_model.get_action_options == { + "action_horizon": 8, + "rtc_overlap_steps": 6, + "rtc_frozen_steps": 3, + "rtc_ramp_rate": 6.0, + } + assert dummy_model.forward_inputs["action"].shape == (1, 8, 132) + torch.testing.assert_close(dummy_model.forward_inputs["action"][0, :, :7], prev_chunk) + torch.testing.assert_close(dummy_model.forward_inputs["action"][0, :, 7:], torch.zeros(8, 125)) + + def test_groot_from_pretrained_rejects_mismatched_caller_config(tmp_path): model_path = tmp_path / "GR00T-N1.7-local" model_path.mkdir() @@ -593,6 +624,27 @@ def test_groot_n1_7_saved_processors_round_trip_checkpoint_specific_fields(tmp_p assert decode_actions.raw_stats["action"]["gripper"]["q99"] == [115.0] +def test_converted_raw_n1_7_processors_load_without_legacy_action_unpack_override(tmp_path): + model_path = tmp_path / "libero_spatial" + _write_raw_n1_7_libero_checkpoint(model_path) + config = _raw_n1_7_libero_config(model_path) + preprocessor, postprocessor = make_pre_post_processors(config, pretrained_path=str(model_path)) + save_dir = tmp_path / "converted_pretrained_model" + + config.save_pretrained(save_dir) + preprocessor.save_pretrained(save_dir) + postprocessor.save_pretrained(save_dir) + + loaded_preprocessor, loaded_postprocessor = make_pre_post_processors( + config, + pretrained_path=str(save_dir), + preprocessor_overrides={"rename_observations_processor": {"rename_map": {}}}, + ) + + assert any(isinstance(step, GrootN17PackInputsStep) for step in loaded_preprocessor.steps) + assert any(isinstance(step, GrootN17ActionDecodeStep) for step in loaded_postprocessor.steps) + + def test_groot_n1_7_pack_inputs_rejects_state_dim_above_core_max(): step = GrootN17PackInputsStep( max_state_dim=2, @@ -941,6 +993,76 @@ def test_groot_n1_7_action_decode_applies_named_libero_transform_from_modality_k torch.testing.assert_close(output[TransitionKey.ACTION], expected) +def test_groot_n1_7_action_decode_truncates_to_valid_horizon_for_relative_stats(): + arm_min = [[-1.0] * 5 for _ in range(16)] + arm_max = [[1.0] * 5 for _ in range(16)] + raw_stats = { + "state": { + "single_arm": _stats([0.0] * 5), + "gripper": _stats([0.0]), + }, + "action": { + "single_arm": _stats([0.0] * 5), + "gripper": { + "min": [0.0], + "max": [10.0], + "mean": [5.0], + "std": [1.0], + "q01": [0.0], + "q99": [10.0], + }, + }, + "relative_action": { + "single_arm": { + "min": arm_min, + "max": arm_max, + "mean": [[0.0] * 5 for _ in range(16)], + "std": [[1.0] * 5 for _ in range(16)], + "q01": arm_min, + "q99": arm_max, + }, + }, + } + modality_config = { + "state": { + "modality_keys": ["single_arm", "gripper"], + }, + "action": { + "delta_indices": list(range(16)), + "modality_keys": ["single_arm", "gripper"], + "action_configs": [ + {"rep": "RELATIVE", "type": "NON_EEF", "format": "DEFAULT", "state_key": None}, + {"rep": "ABSOLUTE", "type": "NON_EEF", "format": "DEFAULT", "state_key": None}, + ], + }, + } + pack_step = GrootN17PackInputsStep( + raw_stats=raw_stats, + modality_config=modality_config, + normalize_min_max=False, + ) + pack_step( + { + TransitionKey.OBSERVATION: {OBS_STATE: torch.zeros(1, 6)}, + TransitionKey.COMPLEMENTARY_DATA: {}, + } + ) + decode_step = GrootN17ActionDecodeStep( + env_action_dim=6, + raw_stats=raw_stats, + modality_config=modality_config, + use_relative_action=True, + pack_step=pack_step, + ) + + output = decode_step({TransitionKey.ACTION: torch.zeros(1, 40, 6)}) + + decoded = output[TransitionKey.ACTION] + assert decoded.shape == (1, 16, 6) + torch.testing.assert_close(decoded[..., :5], torch.zeros(1, 16, 5)) + torch.testing.assert_close(decoded[..., 5], torch.full((1, 16), 5.0)) + + def test_groot_n1_7_action_decode_requires_gripper_key_for_libero_transform(): step = GrootN17ActionDecodeStep( env_action_dim=1,