diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index 50f3e0f27..ab40b7eda 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -18,7 +18,6 @@ from __future__ import annotations import importlib import logging -from copy import copy from typing import TYPE_CHECKING, Any, TypedDict, Unpack import torch @@ -49,7 +48,7 @@ from .act.configuration_act import ACTConfig from .diffusion.configuration_diffusion import DiffusionConfig from .eo1.configuration_eo1 import EO1Config from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig -from .groot.configuration_groot import GROOT_N1_7, GrootConfig +from .groot.configuration_groot import GrootConfig from .molmoact2.configuration_molmoact2 import MolmoAct2Config from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig from .pi0.configuration_pi0 import PI0Config @@ -275,48 +274,21 @@ def make_pre_post_processors( """ if pretrained_path: if isinstance(policy_cfg, GrootConfig): - from .groot.configuration_groot import is_raw_groot_n1_7_checkpoint + from .groot.processor_groot import make_groot_pre_post_processors_from_pretrained - if is_raw_groot_n1_7_checkpoint(pretrained_path): - from .groot.processor_groot import make_groot_pre_post_processors - - processor_cfg = copy(policy_cfg) - processor_cfg.base_model_path = str(pretrained_path) - return make_groot_pre_post_processors( - config=processor_cfg, - dataset_stats=kwargs.get("dataset_stats"), - ) - - # TODO(Steven): Temporary patch, implement correctly the processors for Gr00t - if isinstance(policy_cfg, GrootConfig): - # GROOT handles normalization in its pack-inputs step - # Need to override both stats AND normalize_min_max since saved config might be empty - dataset_stats = kwargs.get("dataset_stats") - preprocessor_overrides = dict(kwargs.get("preprocessor_overrides", {})) - postprocessor_overrides = dict(kwargs.get("postprocessor_overrides", {})) - pack_inputs_key = ( - 
"groot_n1_7_pack_inputs_v1" - if policy_cfg.model_version == GROOT_N1_7 - else "groot_pack_inputs_v3" + return make_groot_pre_post_processors_from_pretrained( + config=policy_cfg, + pretrained_path=pretrained_path, + dataset_stats=kwargs.get("dataset_stats"), + preprocessor_overrides=kwargs.get("preprocessor_overrides"), + postprocessor_overrides=kwargs.get("postprocessor_overrides"), + preprocessor_config_filename=kwargs.get( + "preprocessor_config_filename", f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json" + ), + postprocessor_config_filename=kwargs.get( + "postprocessor_config_filename", f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json" + ), ) - pack_input_overrides = dict(preprocessor_overrides.get(pack_inputs_key, {})) - pack_input_overrides["normalize_min_max"] = True - if dataset_stats is not None and policy_cfg.model_version != GROOT_N1_7: - pack_input_overrides["stats"] = dataset_stats - preprocessor_overrides[pack_inputs_key] = pack_input_overrides - - # Also ensure postprocessing slices to env action dim and unnormalizes with dataset stats - env_action_dim = policy_cfg.output_features[ACTION].shape[0] - action_unpack_overrides = dict( - postprocessor_overrides.get("groot_action_unpack_unnormalize_v1", {}) - ) - action_unpack_overrides["normalize_min_max"] = True - action_unpack_overrides["env_action_dim"] = env_action_dim - if dataset_stats is not None and policy_cfg.model_version != GROOT_N1_7: - action_unpack_overrides["stats"] = dataset_stats - postprocessor_overrides["groot_action_unpack_unnormalize_v1"] = action_unpack_overrides - kwargs["preprocessor_overrides"] = preprocessor_overrides - kwargs["postprocessor_overrides"] = postprocessor_overrides preprocessor = PolicyProcessorPipeline.from_pretrained( pretrained_model_name_or_path=pretrained_path, diff --git a/src/lerobot/policies/groot/processor_groot.py b/src/lerobot/policies/groot/processor_groot.py index 94f895c51..03e0f7877 100644 --- a/src/lerobot/policies/groot/processor_groot.py +++ 
def _legacy_groot_processor_overrides(
    config: GrootConfig,
    dataset_stats: dict[str, dict[str, torch.Tensor]] | None,
    preprocessor_overrides: dict[str, Any] | None = None,
    postprocessor_overrides: dict[str, Any] | None = None,
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Patch older serialized Groot processors with fields current processors expect.

    Builds override dictionaries for the pack-inputs step (keyed by
    ``groot_n1_7_pack_inputs_v1`` for N1.7, ``groot_pack_inputs_v3`` otherwise)
    and for the ``groot_action_unpack_unnormalize_v1`` step.

    Args:
        config: Groot policy config; ``model_version`` selects the pack-inputs
            key and ``output_features[ACTION].shape[0]`` supplies the
            environment action dimension.
        dataset_stats: Stats injected into both steps as ``"stats"``. Ignored
            when ``None`` or when the model version is N1.7.
        preprocessor_overrides: Caller overrides to extend. Not mutated.
        postprocessor_overrides: Caller overrides to extend. Not mutated.

    Returns:
        ``(preprocessor_overrides, postprocessor_overrides)`` as new dicts.
    """
    # Copy the outer dicts and the per-step dicts so caller input is never mutated.
    preprocessor_overrides = dict(preprocessor_overrides or {})
    postprocessor_overrides = dict(postprocessor_overrides or {})

    is_n1_7 = config.model_version == GROOT_N1_7
    inject_stats = dataset_stats is not None and not is_n1_7
    pack_inputs_key = "groot_n1_7_pack_inputs_v1" if is_n1_7 else "groot_pack_inputs_v3"
    unpack_key = "groot_action_unpack_unnormalize_v1"

    pack_input_overrides = dict(preprocessor_overrides.get(pack_inputs_key, {}))
    pack_input_overrides["normalize_min_max"] = True
    if inject_stats:
        pack_input_overrides["stats"] = dataset_stats
    preprocessor_overrides[pack_inputs_key] = pack_input_overrides

    action_unpack_overrides = dict(postprocessor_overrides.get(unpack_key, {}))
    action_unpack_overrides["normalize_min_max"] = True
    try:
        env_action_dim = int(config.output_features[ACTION].shape[0])
    except (AttributeError, IndexError, KeyError, TypeError, ValueError):
        # Best effort: only lookup/conversion failures are tolerated, and they
        # are reported instead of being swallowed. An explicit caller override
        # is kept; otherwise fall back to 0 as before.
        import logging

        logging.getLogger(__name__).warning(
            "Could not infer env_action_dim from config.output_features[%r]; "
            "falling back to %r.",
            ACTION,
            action_unpack_overrides.get("env_action_dim", 0),
        )
        action_unpack_overrides.setdefault("env_action_dim", 0)
    else:
        action_unpack_overrides["env_action_dim"] = env_action_dim
    if inject_stats:
        action_unpack_overrides["stats"] = dataset_stats
    postprocessor_overrides[unpack_key] = action_unpack_overrides

    return preprocessor_overrides, postprocessor_overrides


def make_groot_pre_post_processors_from_pretrained(
    config: GrootConfig,
    pretrained_path: str | Path,
    *,
    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
    preprocessor_overrides: dict[str, Any] | None = None,
    postprocessor_overrides: dict[str, Any] | None = None,
    preprocessor_config_filename: str = f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json",
    postprocessor_config_filename: str = f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json",
) -> tuple[
    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
    PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
    """Load Groot processors while preserving compatibility with older serialized configs.

    Args:
        config: Groot policy config.
        pretrained_path: Location of the checkpoint or serialized processors.
        dataset_stats: Optional dataset statistics forwarded to the processors.
        preprocessor_overrides: Extra overrides for the serialized preprocessor.
        postprocessor_overrides: Extra overrides for the serialized postprocessor.
        preprocessor_config_filename: Preprocessor config file name to load.
        postprocessor_config_filename: Postprocessor config file name to load.

    Returns:
        The ``(preprocessor, postprocessor)`` pipelines.

    Note:
        When ``is_raw_groot_n1_7_checkpoint(pretrained_path)`` is truthy, the
        processors are built with ``make_groot_pre_post_processors`` and the
        override and filename arguments are not used.
    """
    if is_raw_groot_n1_7_checkpoint(pretrained_path):
        # Shallow copy so the caller's config keeps its own base_model_path.
        processor_cfg = copy(config)
        processor_cfg.base_model_path = str(pretrained_path)
        return make_groot_pre_post_processors(
            config=processor_cfg,
            dataset_stats=dataset_stats,
        )

    preprocessor_overrides, postprocessor_overrides = _legacy_groot_processor_overrides(
        config=config,
        dataset_stats=dataset_stats,
        preprocessor_overrides=preprocessor_overrides,
        postprocessor_overrides=postprocessor_overrides,
    )
    preprocessor = PolicyProcessorPipeline.from_pretrained(
        pretrained_model_name_or_path=pretrained_path,
        config_filename=preprocessor_config_filename,
        overrides=preprocessor_overrides,
        to_transition=batch_to_transition,
        to_output=transition_to_batch,
    )
    postprocessor = PolicyProcessorPipeline.from_pretrained(
        pretrained_model_name_or_path=pretrained_path,
        config_filename=postprocessor_config_filename,
        overrides=postprocessor_overrides,
        to_transition=policy_action_to_transition,
        to_output=transition_to_policy_action,
    )
    return preprocessor, postprocessor
def test_groot_legacy_n1_5_processors_reload_with_compatibility_overrides(tmp_path):
    """Legacy N1.5 processor configs reloaded via the factory get the compatibility overrides."""
    config = _groot_config(GROOT_N1_5)

    def _symmetric_bounds(dim, bound):
        # Min/max stats of the form [-bound, +bound] for a `dim`-sized vector.
        return {"min": torch.full((dim,), -bound), "max": torch.full((dim,), bound)}

    dataset_stats = {
        OBS_STATE: _symmetric_bounds(8, 1.0),
        ACTION: _symmetric_bounds(7, 2.0),
    }

    # Serialized step configs as an older release would have written them:
    # normalization disabled, no stats, and no environment action dimension.
    pack_step_config = {
        "state_horizon": 1,
        "action_horizon": 16,
        "max_state_dim": config.max_state_dim,
        "max_action_dim": config.max_action_dim,
        "language_key": "task",
        "formalize_language": False,
        "embodiment_tag": config.embodiment_tag,
        "embodiment_mapping": {"new_embodiment": 31},
        "normalize_min_max": False,
    }
    unpack_step_config = {"env_action_dim": 0, "normalize_min_max": False}

    serialized_pipelines = {
        "policy_preprocessor.json": {
            "name": "policy_preprocessor",
            "steps": [{"registry_name": "groot_pack_inputs_v3", "config": pack_step_config}],
        },
        "policy_postprocessor.json": {
            "name": "policy_postprocessor",
            "steps": [
                {
                    "registry_name": "groot_action_unpack_unnormalize_v1",
                    "config": unpack_step_config,
                }
            ],
        },
    }
    for filename, payload in serialized_pipelines.items():
        (tmp_path / filename).write_text(json.dumps(payload))

    loaded_preprocessor, loaded_postprocessor = make_pre_post_processors(
        config,
        pretrained_path=str(tmp_path),
        dataset_stats=dataset_stats,
    )

    pack_step = loaded_preprocessor.steps[0]
    unpack_step = loaded_postprocessor.steps[0]
    loaded_steps = (pack_step, unpack_step)

    for step in loaded_steps:
        assert step.normalize_min_max
    assert unpack_step.env_action_dim == 7
    for step in loaded_steps:
        torch.testing.assert_close(step.stats[OBS_STATE]["min"], dataset_stats[OBS_STATE]["min"])
        torch.testing.assert_close(step.stats[ACTION]["max"], dataset_stats[ACTION]["max"])