fix(groot): skip normalization overrides for training

This commit is contained in:
Steven Palma
2026-06-13 19:51:29 +02:00
parent fcb371eddd
commit 378897800a
+44 -1
View File
@@ -447,6 +447,42 @@ def _has_modality_stats(stats: dict[str, dict[str, Any]] | None) -> bool:
return any(bool(modality_stats) for modality_stats in stats.values())
# GR00T normalizes state/action inside its own processor steps and so deliberately has no
# NormalizerProcessorStep/UnnormalizerProcessorStep (see GrootConfig.normalization_mapping, which is
# IDENTITY for every feature). lerobot-train nonetheless emits these standard override keys
# unconditionally, so for a GR00T pipeline they legitimately match no step. They are dropped up front
# by _drop_groot_absent_standard_overrides so they neither break loading nor mask genuine typos.
_GROOT_ABSENT_STANDARD_OVERRIDE_KEYS = frozenset({"normalizer_processor", "unnormalizer_processor"})
def _drop_groot_absent_standard_overrides(overrides: dict[str, Any] | None) -> dict[str, Any] | None:
"""Strip standard normalization override keys that a GR00T pipeline has no step for.
``lerobot-train`` emits ``normalizer_processor``/``unnormalizer_processor`` overrides
unconditionally, but GR00T normalizes inside its own steps and has no such step (see
``GrootConfig.normalization_mapping``). Both override-application paths reject keys that match no
step — ``_apply_groot_step_overrides`` raises for the freshly built raw-checkpoint pipeline, and
``PolicyProcessorPipeline.from_pretrained`` raises via its used-override validation for the
serialized pipeline — so these keys are removed before either path runs. Any other unknown key
(e.g. a typo) is left in place and still raises.
"""
if not overrides:
return overrides
filtered: dict[str, Any] = {}
for key, value in overrides.items():
if key in _GROOT_ABSENT_STANDARD_OVERRIDE_KEYS:
logging.debug(
"Ignoring override key '%s': GR00T normalizes inside its own processor steps and has "
"no matching step (see GrootConfig.normalization_mapping).",
key,
)
continue
filtered[key] = value
return filtered
def _apply_groot_step_overrides(
pipeline: PolicyProcessorPipeline,
overrides: dict[str, Any] | None,
@@ -460,7 +496,8 @@ def _apply_groot_step_overrides(
steps by registry name only — prefer registry names so overrides keep
working after the checkpoint is converted and reloaded from a serialized
pipeline). Keys or fields that match nothing raise instead of being dropped
silently.
silently (standard normalization keys GR00T has no step for are removed
beforehand by ``_drop_groot_absent_standard_overrides``).
"""
if not overrides:
@@ -518,6 +555,12 @@ def make_groot_pre_post_processors_from_pretrained(
]:
"""Load Groot processors for a raw N1.7 checkpoint or a serialized LeRobot pipeline."""
# Drop the standard normalizer/unnormalizer override keys lerobot-train emits unconditionally:
# GR00T has no such steps, so they would make both the raw-checkpoint and serialized override
# paths raise. This must happen before either branch below.
preprocessor_overrides = _drop_groot_absent_standard_overrides(preprocessor_overrides)
postprocessor_overrides = _drop_groot_absent_standard_overrides(postprocessor_overrides)
if is_raw_groot_n1_7_checkpoint(pretrained_path):
processor_cfg = copy(config)
processor_cfg.base_model_path = str(pretrained_path)