From a35e6a4b46c33ccf46273fa655b98ef81506924a Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Tue, 30 Jun 2026 14:31:49 +0200 Subject: [PATCH] chore(policies): add guards, warnings and comments + recover tests n1.5 check --- .../policies/groot/configuration_groot.py | 9 +++- src/lerobot/policies/groot/processor_groot.py | 36 +++++++++++++ tests/policies/groot/test_groot_n1_7.py | 50 +++++++++++++++++++ 3 files changed, 93 insertions(+), 2 deletions(-) diff --git a/src/lerobot/policies/groot/configuration_groot.py b/src/lerobot/policies/groot/configuration_groot.py index cd803f31a..c64c6143f 100644 --- a/src/lerobot/policies/groot/configuration_groot.py +++ b/src/lerobot/policies/groot/configuration_groot.py @@ -324,9 +324,14 @@ class GrootConfig(PreTrainedConfig): # Set to True only after installing a flash-attn build matching your torch/CUDA env. use_flash_attention: bool = False - # Enable GR00T-style state-relative action chunks. Prefer deriving action representation from - # embodiment metadata; relative_exclude_joints is a flat-vector override for datasets without it. + # Enable GR00T-style state-relative action chunks (action chunk expressed relative to the current + # observation state). use_relative_actions: bool = False + + # relative_exclude_joints names the action dimensions that stay absolute; the + # match is substring/case-insensitive against the dataset action feature names. With the empty + # default every dimension is treated as relative, including the gripper -- set e.g. ["gripper"] to + # keep the gripper absolute, matching the Isaac-GR00T single-arm + absolute-gripper convention. relative_exclude_joints: list[str] = field(default_factory=list) # Training parameters diff --git a/src/lerobot/policies/groot/processor_groot.py b/src/lerobot/policies/groot/processor_groot.py index 205f0a2b6..63b0166e5 100644 --- a/src/lerobot/policies/groot/processor_groot.py +++ b/src/lerobot/policies/groot/processor_groot.py @@ -996,6 +996,7 @@ def _build_n1_7_relative_action_processor_assets( } for group in groups ] + # 40 matches the action horizon of the only N1.7 base model (nvidia/GR00T-N1.7-3B) action_horizon = min(config.chunk_size, 40) modality_config: dict[str, Any] = { "state": {"modality_keys": [group.key for group in groups]}, @@ -1194,6 +1195,13 @@ def make_groot_pre_post_processors( ) relative_step: RelativeActionsProcessorStep | None = None if config.use_relative_actions and not uses_native_relative_actions: + logging.warning( + "GR00T relative actions are using the generic RelativeActionsProcessorStep fallback because " + "the checkpoint already carries non-relative statistics. Relative deltas will be normalized " + "with absolute action stats rather than Isaac-GR00T's per-horizon relative stats. For " + "OSS-faithful relative normalization, build from a checkpoint without baked-in stats (or " + "pass dataset_meta) so native relative stats are computed." + ) relative_step = RelativeActionsProcessorStep( enabled=True, exclude_joints=list(config.relative_exclude_joints or []), @@ -1658,6 +1666,25 @@ class GrootN17PackInputsStep(ProcessorStep): return None return torch.cat(normalized_groups, dim=-1) + def _uses_relative_action_groups(self) -> bool: + """True when the action modality declares at least one relative group. + + Relative groups normalize with per-chunk-timestep (2D) ``relative_action`` stats, which the + flat ``_min_max_norm`` fallback cannot honor, so a relative config that fails grouped + normalization must fail loudly rather than silently mis-scale every timestep. + """ + if not isinstance(self.modality_config, dict): + return False + action_config = self.modality_config.get("action", {}) + if not isinstance(action_config, dict): + return False + action_configs = action_config.get("action_configs", []) + if not isinstance(action_configs, list): + return False + return any( + isinstance(cfg, dict) and config_value(cfg.get("rep")) == "relative" for cfg in action_configs + ) + def __call__(self, transition: EnvTransition) -> EnvTransition: obs = transition.get(TransitionKey.OBSERVATION, {}) or {} comp = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {} @@ -1775,6 +1802,15 @@ class GrootN17PackInputsStep(ProcessorStep): normalized_action = self._normalize_action_groups_for_training(action) if normalized_action is not None: action = normalized_action + elif self._uses_relative_action_groups(): + raise ValueError( + "GrootN17PackInputsStep could not apply native grouped normalization to a " + "relative-action chunk: the action layout or horizon does not match the " + f"checkpoint relative_action stats (action shape {tuple(action.shape)}). The flat " + "min/max fallback cannot honor per-chunk-timestep relative stats, so refusing to " + "silently mis-normalize. Recompute the relative action stats so their horizon and " + "dimensions match the action chunk." + ) else: flat = _min_max_norm(action.reshape(bsz * horizon, dim), ACTION) action = flat.view(bsz, horizon, dim) diff --git a/tests/policies/groot/test_groot_n1_7.py b/tests/policies/groot/test_groot_n1_7.py index 7a1a5af26..d8f8cfbc6 100644 --- a/tests/policies/groot/test_groot_n1_7.py +++ b/tests/policies/groot/test_groot_n1_7.py @@ -30,10 +30,12 @@ from lerobot.configs import FeatureType, PolicyFeature from lerobot.policies.factory import make_policy_config, make_pre_post_processors from lerobot.policies.groot.configuration_groot import ( GROOT_ACTION_DECODE_TRANSFORM_LIBERO, + GROOT_N1_7, GROOT_N1_7_BASE_MODEL, GrootConfig, infer_groot_n1_7_action_execution_horizon, infer_groot_n1_7_action_horizon, + normalize_groot_model_version, ) from lerobot.policies.groot.modeling_groot import GrootPolicy from lerobot.policies.groot.processor_groot import ( @@ -350,6 +352,18 @@ def test_groot_defaults_use_n1_7(): assert len(config.action_delta_indices) == 40 +@pytest.mark.parametrize("legacy_version", ["n1.5", "n1_5", "n15", "1.5"]) +def test_groot_normalize_model_version_rejects_n1_5_aliases(legacy_version): + # model_version is no longer a GrootConfig field, but normalize_groot_model_version is still + # live (e.g. via infer_groot_model_version) and must keep rejecting N1.5 with removal guidance. + with pytest.raises(ValueError, match="Unsupported GR00T model_version"): + normalize_groot_model_version(legacy_version) + + +def test_groot_normalize_model_version_accepts_n1_7(): + assert normalize_groot_model_version(GROOT_N1_7) == GROOT_N1_7 + + def test_groot_n1_7_accepts_named_action_decode_transform(): config = GrootConfig( action_decode_transform="libero", @@ -997,6 +1011,42 @@ def test_groot_n1_7_pack_inputs_normalizes_action_chunk_per_dimension_before_pad assert action_mask[0, :, 3:].sum().item() == 0 +def test_groot_n1_7_pack_inputs_raises_when_relative_groups_cannot_normalize(): + # Relative groups carry per-chunk-timestep stats; if the action horizon exceeds the available + # stat rows, grouped normalization cannot apply and the flat fallback would silently mis-scale. + step = GrootN17PackInputsStep( + action_horizon=3, + valid_action_horizon=3, + max_state_dim=2, + max_action_dim=2, + normalize_min_max=True, + raw_stats={ + "state": {"single_arm": {"min": [0.0, 0.0], "max": [1.0, 1.0]}}, + "action": {"single_arm": {"min": [0.0, 0.0], "max": [1.0, 1.0]}}, + # only one horizon row, but the action chunk has horizon 3 + "relative_action": {"single_arm": {"min": [[-1.0, -1.0]], "max": [[1.0, 1.0]]}}, + }, + modality_config={ + "state": {"modality_keys": ["single_arm"]}, + "action": { + "modality_keys": ["single_arm"], + "action_configs": [ + {"rep": "RELATIVE", "type": "NON_EEF", "format": "DEFAULT", "state_key": None} + ], + "delta_indices": [0, 1, 2], + }, + }, + ) + transition = { + TransitionKey.OBSERVATION: {OBS_STATE: torch.zeros(1, 2)}, + TransitionKey.ACTION: torch.zeros(1, 3, 2), + TransitionKey.COMPLEMENTARY_DATA: {"task": ["Move"]}, + } + + with pytest.raises(ValueError, match="could not apply native grouped normalization"): + step(transition) + + def test_groot_n1_7_pack_inputs_trains_native_relative_groups_with_absolute_gripper(): step = GrootN17PackInputsStep( action_horizon=2,