Move Groot processor compatibility into Groot loader

This commit is contained in:
Andrew Wrenn
2026-06-02 13:19:12 -07:00
parent b568c41355
commit e3b203e5a7
3 changed files with 160 additions and 42 deletions
+14 -42
View File
@@ -18,7 +18,6 @@ from __future__ import annotations
import importlib
import logging
from copy import copy
from typing import TYPE_CHECKING, Any, TypedDict, Unpack
import torch
@@ -49,7 +48,7 @@ from .act.configuration_act import ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig
from .eo1.configuration_eo1 import EO1Config
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
from .groot.configuration_groot import GROOT_N1_7, GrootConfig
from .groot.configuration_groot import GrootConfig
from .molmoact2.configuration_molmoact2 import MolmoAct2Config
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from .pi0.configuration_pi0 import PI0Config
@@ -275,48 +274,21 @@ def make_pre_post_processors(
"""
if pretrained_path:
if isinstance(policy_cfg, GrootConfig):
from .groot.configuration_groot import is_raw_groot_n1_7_checkpoint
from .groot.processor_groot import make_groot_pre_post_processors_from_pretrained
if is_raw_groot_n1_7_checkpoint(pretrained_path):
from .groot.processor_groot import make_groot_pre_post_processors
processor_cfg = copy(policy_cfg)
processor_cfg.base_model_path = str(pretrained_path)
return make_groot_pre_post_processors(
config=processor_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
# TODO(Steven): Temporary patch, implement correctly the processors for Gr00t
if isinstance(policy_cfg, GrootConfig):
# GROOT handles normalization in its pack-inputs step
# Need to override both stats AND normalize_min_max since saved config might be empty
dataset_stats = kwargs.get("dataset_stats")
preprocessor_overrides = dict(kwargs.get("preprocessor_overrides", {}))
postprocessor_overrides = dict(kwargs.get("postprocessor_overrides", {}))
pack_inputs_key = (
"groot_n1_7_pack_inputs_v1"
if policy_cfg.model_version == GROOT_N1_7
else "groot_pack_inputs_v3"
return make_groot_pre_post_processors_from_pretrained(
config=policy_cfg,
pretrained_path=pretrained_path,
dataset_stats=kwargs.get("dataset_stats"),
preprocessor_overrides=kwargs.get("preprocessor_overrides"),
postprocessor_overrides=kwargs.get("postprocessor_overrides"),
preprocessor_config_filename=kwargs.get(
"preprocessor_config_filename", f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json"
),
postprocessor_config_filename=kwargs.get(
"postprocessor_config_filename", f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json"
),
)
pack_input_overrides = dict(preprocessor_overrides.get(pack_inputs_key, {}))
pack_input_overrides["normalize_min_max"] = True
if dataset_stats is not None and policy_cfg.model_version != GROOT_N1_7:
pack_input_overrides["stats"] = dataset_stats
preprocessor_overrides[pack_inputs_key] = pack_input_overrides
# Also ensure postprocessing slices to env action dim and unnormalizes with dataset stats
env_action_dim = policy_cfg.output_features[ACTION].shape[0]
action_unpack_overrides = dict(
postprocessor_overrides.get("groot_action_unpack_unnormalize_v1", {})
)
action_unpack_overrides["normalize_min_max"] = True
action_unpack_overrides["env_action_dim"] = env_action_dim
if dataset_stats is not None and policy_cfg.model_version != GROOT_N1_7:
action_unpack_overrides["stats"] = dataset_stats
postprocessor_overrides["groot_action_unpack_unnormalize_v1"] = action_unpack_overrides
kwargs["preprocessor_overrides"] = preprocessor_overrides
kwargs["postprocessor_overrides"] = postprocessor_overrides
preprocessor = PolicyProcessorPipeline.from_pretrained(
pretrained_model_name_or_path=pretrained_path,
@@ -15,6 +15,7 @@
# limitations under the License.
import json
from copy import copy
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any
@@ -45,7 +46,9 @@ from lerobot.processor import (
ProcessorStep,
ProcessorStepRegistry,
RenameObservationsProcessorStep,
batch_to_transition,
policy_action_to_transition,
transition_to_batch,
transition_to_policy_action,
)
from lerobot.types import EnvTransition, TransitionKey
@@ -457,6 +460,86 @@ def _has_modality_stats(stats: dict[str, dict[str, Any]] | None) -> bool:
return any(bool(modality_stats) for modality_stats in stats.values())
def _legacy_groot_processor_overrides(
config: GrootConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None,
preprocessor_overrides: dict[str, Any] | None = None,
postprocessor_overrides: dict[str, Any] | None = None,
) -> tuple[dict[str, Any], dict[str, Any]]:
"""Patch older serialized Groot processors with fields current processors expect."""
preprocessor_overrides = dict(preprocessor_overrides or {})
postprocessor_overrides = dict(postprocessor_overrides or {})
pack_inputs_key = (
"groot_n1_7_pack_inputs_v1" if config.model_version == GROOT_N1_7 else "groot_pack_inputs_v3"
)
pack_input_overrides = dict(preprocessor_overrides.get(pack_inputs_key, {}))
pack_input_overrides["normalize_min_max"] = True
if dataset_stats is not None and config.model_version != GROOT_N1_7:
pack_input_overrides["stats"] = dataset_stats
preprocessor_overrides[pack_inputs_key] = pack_input_overrides
try:
env_action_dim = int(config.output_features[ACTION].shape[0])
except Exception:
env_action_dim = 0
action_unpack_overrides = dict(postprocessor_overrides.get("groot_action_unpack_unnormalize_v1", {}))
action_unpack_overrides["normalize_min_max"] = True
action_unpack_overrides["env_action_dim"] = env_action_dim
if dataset_stats is not None and config.model_version != GROOT_N1_7:
action_unpack_overrides["stats"] = dataset_stats
postprocessor_overrides["groot_action_unpack_unnormalize_v1"] = action_unpack_overrides
return preprocessor_overrides, postprocessor_overrides
def make_groot_pre_post_processors_from_pretrained(
config: GrootConfig,
pretrained_path: str,
*,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
preprocessor_overrides: dict[str, Any] | None = None,
postprocessor_overrides: dict[str, Any] | None = None,
preprocessor_config_filename: str = f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json",
postprocessor_config_filename: str = f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json",
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""Load Groot processors while preserving compatibility with older serialized configs."""
if is_raw_groot_n1_7_checkpoint(pretrained_path):
processor_cfg = copy(config)
processor_cfg.base_model_path = str(pretrained_path)
return make_groot_pre_post_processors(
config=processor_cfg,
dataset_stats=dataset_stats,
)
preprocessor_overrides, postprocessor_overrides = _legacy_groot_processor_overrides(
config=config,
dataset_stats=dataset_stats,
preprocessor_overrides=preprocessor_overrides,
postprocessor_overrides=postprocessor_overrides,
)
preprocessor = PolicyProcessorPipeline.from_pretrained(
pretrained_model_name_or_path=pretrained_path,
config_filename=preprocessor_config_filename,
overrides=preprocessor_overrides,
to_transition=batch_to_transition,
to_output=transition_to_batch,
)
postprocessor = PolicyProcessorPipeline.from_pretrained(
pretrained_model_name_or_path=pretrained_path,
config_filename=postprocessor_config_filename,
overrides=postprocessor_overrides,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
)
return preprocessor, postprocessor
def make_groot_pre_post_processors(
config: GrootConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[
+63
View File
@@ -1475,6 +1475,69 @@ def test_groot_n1_7_saved_processors_reload_through_factory_preserves_saved_stat
assert unpack_step.env_action_dim == 7
def test_groot_legacy_n1_5_processors_reload_with_compatibility_overrides(tmp_path):
config = _groot_config(GROOT_N1_5)
dataset_stats = {
OBS_STATE: {
"min": torch.full((8,), -1.0),
"max": torch.full((8,), 1.0),
},
ACTION: {
"min": torch.full((7,), -2.0),
"max": torch.full((7,), 2.0),
},
}
legacy_preprocessor_config = {
"name": "policy_preprocessor",
"steps": [
{
"registry_name": "groot_pack_inputs_v3",
"config": {
"state_horizon": 1,
"action_horizon": 16,
"max_state_dim": config.max_state_dim,
"max_action_dim": config.max_action_dim,
"language_key": "task",
"formalize_language": False,
"embodiment_tag": config.embodiment_tag,
"embodiment_mapping": {"new_embodiment": 31},
"normalize_min_max": False,
},
}
],
}
legacy_postprocessor_config = {
"name": "policy_postprocessor",
"steps": [
{
"registry_name": "groot_action_unpack_unnormalize_v1",
"config": {
"env_action_dim": 0,
"normalize_min_max": False,
},
}
],
}
(tmp_path / "policy_preprocessor.json").write_text(json.dumps(legacy_preprocessor_config))
(tmp_path / "policy_postprocessor.json").write_text(json.dumps(legacy_postprocessor_config))
loaded_preprocessor, loaded_postprocessor = make_pre_post_processors(
config,
pretrained_path=str(tmp_path),
dataset_stats=dataset_stats,
)
pack_step = loaded_preprocessor.steps[0]
unpack_step = loaded_postprocessor.steps[0]
assert pack_step.normalize_min_max
assert unpack_step.normalize_min_max
assert unpack_step.env_action_dim == 7
torch.testing.assert_close(pack_step.stats[OBS_STATE]["min"], dataset_stats[OBS_STATE]["min"])
torch.testing.assert_close(pack_step.stats[ACTION]["max"], dataset_stats[ACTION]["max"])
torch.testing.assert_close(unpack_step.stats[OBS_STATE]["min"], dataset_stats[OBS_STATE]["min"])
torch.testing.assert_close(unpack_step.stats[ACTION]["max"], dataset_stats[ACTION]["max"])
def test_groot_policy_selects_n1_7_model_class(monkeypatch):
from lerobot.policies.groot.groot_n1_7 import GR00TN17