chore(policies): add guards, warnings and comments + recover tests n1.5 check

This commit is contained in:
Steven Palma
2026-06-30 14:31:49 +02:00
parent 4a3f46d0ec
commit a35e6a4b46
3 changed files with 93 additions and 2 deletions
@@ -324,9 +324,14 @@ class GrootConfig(PreTrainedConfig):
# Set to True only after installing a flash-attn build matching your torch/CUDA env.
use_flash_attention: bool = False
# Enable GR00T-style state-relative action chunks. Prefer deriving action representation from
# embodiment metadata; relative_exclude_joints is a flat-vector override for datasets without it.
# Enable GR00T-style state-relative action chunks (action chunk expressed relative to the current
# observation state).
use_relative_actions: bool = False
# relative_exclude_joints names the action dimensions that stay absolute; the
# match is substring/case-insensitive against the dataset action feature names. With the empty
# default every dimension is treated as relative, including the gripper -- set e.g. ["gripper"] to
# keep the gripper absolute, matching the Isaac-GR00T single-arm + absolute-gripper convention.
relative_exclude_joints: list[str] = field(default_factory=list)
# Training parameters
@@ -996,6 +996,7 @@ def _build_n1_7_relative_action_processor_assets(
}
for group in groups
]
# 40 matches the action horizon of the only N1.7 base model (nvidia/GR00T-N1.7-3B)
action_horizon = min(config.chunk_size, 40)
modality_config: dict[str, Any] = {
"state": {"modality_keys": [group.key for group in groups]},
@@ -1194,6 +1195,13 @@ def make_groot_pre_post_processors(
)
relative_step: RelativeActionsProcessorStep | None = None
if config.use_relative_actions and not uses_native_relative_actions:
logging.warning(
"GR00T relative actions are using the generic RelativeActionsProcessorStep fallback because "
"the checkpoint already carries non-relative statistics. Relative deltas will be normalized "
"with absolute action stats rather than Isaac-GR00T's per-horizon relative stats. For "
"OSS-faithful relative normalization, build from a checkpoint without baked-in stats (or "
"pass dataset_meta) so native relative stats are computed."
)
relative_step = RelativeActionsProcessorStep(
enabled=True,
exclude_joints=list(config.relative_exclude_joints or []),
@@ -1658,6 +1666,25 @@ class GrootN17PackInputsStep(ProcessorStep):
return None
return torch.cat(normalized_groups, dim=-1)
def _uses_relative_action_groups(self) -> bool:
"""True when the action modality declares at least one relative group.
Relative groups normalize with per-chunk-timestep (2D) ``relative_action`` stats, which the
flat ``_min_max_norm`` fallback cannot honor, so a relative config that fails grouped
normalization must fail loudly rather than silently mis-scale every timestep.
"""
if not isinstance(self.modality_config, dict):
return False
action_config = self.modality_config.get("action", {})
if not isinstance(action_config, dict):
return False
action_configs = action_config.get("action_configs", [])
if not isinstance(action_configs, list):
return False
return any(
isinstance(cfg, dict) and config_value(cfg.get("rep")) == "relative" for cfg in action_configs
)
def __call__(self, transition: EnvTransition) -> EnvTransition:
obs = transition.get(TransitionKey.OBSERVATION, {}) or {}
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {}
@@ -1775,6 +1802,15 @@ class GrootN17PackInputsStep(ProcessorStep):
normalized_action = self._normalize_action_groups_for_training(action)
if normalized_action is not None:
action = normalized_action
elif self._uses_relative_action_groups():
raise ValueError(
"GrootN17PackInputsStep could not apply native grouped normalization to a "
"relative-action chunk: the action layout or horizon does not match the "
f"checkpoint relative_action stats (action shape {tuple(action.shape)}). The flat "
"min/max fallback cannot honor per-chunk-timestep relative stats, so refusing to "
"silently mis-normalize. Recompute the relative action stats so their horizon and "
"dimensions match the action chunk."
)
else:
flat = _min_max_norm(action.reshape(bsz * horizon, dim), ACTION)
action = flat.view(bsz, horizon, dim)
+50
View File
@@ -30,10 +30,12 @@ from lerobot.configs import FeatureType, PolicyFeature
from lerobot.policies.factory import make_policy_config, make_pre_post_processors
from lerobot.policies.groot.configuration_groot import (
GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
GROOT_N1_7,
GROOT_N1_7_BASE_MODEL,
GrootConfig,
infer_groot_n1_7_action_execution_horizon,
infer_groot_n1_7_action_horizon,
normalize_groot_model_version,
)
from lerobot.policies.groot.modeling_groot import GrootPolicy
from lerobot.policies.groot.processor_groot import (
@@ -350,6 +352,18 @@ def test_groot_defaults_use_n1_7():
assert len(config.action_delta_indices) == 40
@pytest.mark.parametrize("legacy_version", ["n1.5", "n1_5", "n15", "1.5"])
def test_groot_normalize_model_version_rejects_n1_5_aliases(legacy_version):
# model_version is no longer a GrootConfig field, but normalize_groot_model_version is still
# live (e.g. via infer_groot_model_version) and must keep rejecting N1.5 with removal guidance.
with pytest.raises(ValueError, match="Unsupported GR00T model_version"):
normalize_groot_model_version(legacy_version)
def test_groot_normalize_model_version_accepts_n1_7():
assert normalize_groot_model_version(GROOT_N1_7) == GROOT_N1_7
def test_groot_n1_7_accepts_named_action_decode_transform():
config = GrootConfig(
action_decode_transform="libero",
@@ -997,6 +1011,42 @@ def test_groot_n1_7_pack_inputs_normalizes_action_chunk_per_dimension_before_pad
assert action_mask[0, :, 3:].sum().item() == 0
def test_groot_n1_7_pack_inputs_raises_when_relative_groups_cannot_normalize():
# Relative groups carry per-chunk-timestep stats; if the action horizon exceeds the available
# stat rows, grouped normalization cannot apply and the flat fallback would silently mis-scale.
step = GrootN17PackInputsStep(
action_horizon=3,
valid_action_horizon=3,
max_state_dim=2,
max_action_dim=2,
normalize_min_max=True,
raw_stats={
"state": {"single_arm": {"min": [0.0, 0.0], "max": [1.0, 1.0]}},
"action": {"single_arm": {"min": [0.0, 0.0], "max": [1.0, 1.0]}},
# only one horizon row, but the action chunk has horizon 3
"relative_action": {"single_arm": {"min": [[-1.0, -1.0]], "max": [[1.0, 1.0]]}},
},
modality_config={
"state": {"modality_keys": ["single_arm"]},
"action": {
"modality_keys": ["single_arm"],
"action_configs": [
{"rep": "RELATIVE", "type": "NON_EEF", "format": "DEFAULT", "state_key": None}
],
"delta_indices": [0, 1, 2],
},
},
)
transition = {
TransitionKey.OBSERVATION: {OBS_STATE: torch.zeros(1, 2)},
TransitionKey.ACTION: torch.zeros(1, 3, 2),
TransitionKey.COMPLEMENTARY_DATA: {"task": ["Move"]},
}
with pytest.raises(ValueError, match="could not apply native grouped normalization"):
step(transition)
def test_groot_n1_7_pack_inputs_trains_native_relative_groups_with_absolute_gripper():
step = GrootN17PackInputsStep(
action_horizon=2,