mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-17 16:27:04 +00:00
Move Groot processor compatibility into Groot loader
This commit is contained in:
@@ -18,7 +18,6 @@ from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
from copy import copy
|
||||
from typing import TYPE_CHECKING, Any, TypedDict, Unpack
|
||||
|
||||
import torch
|
||||
@@ -49,7 +48,7 @@ from .act.configuration_act import ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig
|
||||
from .eo1.configuration_eo1 import EO1Config
|
||||
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
|
||||
from .groot.configuration_groot import GROOT_N1_7, GrootConfig
|
||||
from .groot.configuration_groot import GrootConfig
|
||||
from .molmoact2.configuration_molmoact2 import MolmoAct2Config
|
||||
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
|
||||
from .pi0.configuration_pi0 import PI0Config
|
||||
@@ -275,48 +274,21 @@ def make_pre_post_processors(
|
||||
"""
|
||||
if pretrained_path:
|
||||
if isinstance(policy_cfg, GrootConfig):
|
||||
from .groot.configuration_groot import is_raw_groot_n1_7_checkpoint
|
||||
from .groot.processor_groot import make_groot_pre_post_processors_from_pretrained
|
||||
|
||||
if is_raw_groot_n1_7_checkpoint(pretrained_path):
|
||||
from .groot.processor_groot import make_groot_pre_post_processors
|
||||
|
||||
processor_cfg = copy(policy_cfg)
|
||||
processor_cfg.base_model_path = str(pretrained_path)
|
||||
return make_groot_pre_post_processors(
|
||||
config=processor_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
# TODO(Steven): Temporary patch, implement correctly the processors for Gr00t
|
||||
if isinstance(policy_cfg, GrootConfig):
|
||||
# GROOT handles normalization in its pack-inputs step
|
||||
# Need to override both stats AND normalize_min_max since saved config might be empty
|
||||
dataset_stats = kwargs.get("dataset_stats")
|
||||
preprocessor_overrides = dict(kwargs.get("preprocessor_overrides", {}))
|
||||
postprocessor_overrides = dict(kwargs.get("postprocessor_overrides", {}))
|
||||
pack_inputs_key = (
|
||||
"groot_n1_7_pack_inputs_v1"
|
||||
if policy_cfg.model_version == GROOT_N1_7
|
||||
else "groot_pack_inputs_v3"
|
||||
return make_groot_pre_post_processors_from_pretrained(
|
||||
config=policy_cfg,
|
||||
pretrained_path=pretrained_path,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
preprocessor_overrides=kwargs.get("preprocessor_overrides"),
|
||||
postprocessor_overrides=kwargs.get("postprocessor_overrides"),
|
||||
preprocessor_config_filename=kwargs.get(
|
||||
"preprocessor_config_filename", f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json"
|
||||
),
|
||||
postprocessor_config_filename=kwargs.get(
|
||||
"postprocessor_config_filename", f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json"
|
||||
),
|
||||
)
|
||||
pack_input_overrides = dict(preprocessor_overrides.get(pack_inputs_key, {}))
|
||||
pack_input_overrides["normalize_min_max"] = True
|
||||
if dataset_stats is not None and policy_cfg.model_version != GROOT_N1_7:
|
||||
pack_input_overrides["stats"] = dataset_stats
|
||||
preprocessor_overrides[pack_inputs_key] = pack_input_overrides
|
||||
|
||||
# Also ensure postprocessing slices to env action dim and unnormalizes with dataset stats
|
||||
env_action_dim = policy_cfg.output_features[ACTION].shape[0]
|
||||
action_unpack_overrides = dict(
|
||||
postprocessor_overrides.get("groot_action_unpack_unnormalize_v1", {})
|
||||
)
|
||||
action_unpack_overrides["normalize_min_max"] = True
|
||||
action_unpack_overrides["env_action_dim"] = env_action_dim
|
||||
if dataset_stats is not None and policy_cfg.model_version != GROOT_N1_7:
|
||||
action_unpack_overrides["stats"] = dataset_stats
|
||||
postprocessor_overrides["groot_action_unpack_unnormalize_v1"] = action_unpack_overrides
|
||||
kwargs["preprocessor_overrides"] = preprocessor_overrides
|
||||
kwargs["postprocessor_overrides"] = postprocessor_overrides
|
||||
|
||||
preprocessor = PolicyProcessorPipeline.from_pretrained(
|
||||
pretrained_model_name_or_path=pretrained_path,
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
from copy import copy
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
@@ -45,7 +46,9 @@ from lerobot.processor import (
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
batch_to_transition,
|
||||
policy_action_to_transition,
|
||||
transition_to_batch,
|
||||
transition_to_policy_action,
|
||||
)
|
||||
from lerobot.types import EnvTransition, TransitionKey
|
||||
@@ -457,6 +460,86 @@ def _has_modality_stats(stats: dict[str, dict[str, Any]] | None) -> bool:
|
||||
return any(bool(modality_stats) for modality_stats in stats.values())
|
||||
|
||||
|
||||
def _legacy_groot_processor_overrides(
    config: GrootConfig,
    dataset_stats: dict[str, dict[str, torch.Tensor]] | None,
    preprocessor_overrides: dict[str, Any] | None = None,
    postprocessor_overrides: dict[str, Any] | None = None,
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Patch older serialized Groot processors with fields current processors expect.

    The caller-supplied override mappings are never mutated: both the outer
    mappings and the per-step entries that get patched are copied first.

    Args:
        config: Groot policy configuration. ``model_version`` selects which
            pack-inputs step is patched, and ``output_features[ACTION]``
            provides the environment action dimension.
        dataset_stats: Optional dataset statistics. They are injected into the
            patched steps only when provided and the model version is not
            ``GROOT_N1_7``.
        preprocessor_overrides: Optional user overrides for the preprocessor,
            keyed by step registry name.
        postprocessor_overrides: Optional user overrides for the postprocessor,
            keyed by step registry name.

    Returns:
        A ``(preprocessor_overrides, postprocessor_overrides)`` tuple of new
        dictionaries containing the user overrides plus the compatibility
        patches.
    """
    preprocessor_overrides = dict(preprocessor_overrides or {})
    postprocessor_overrides = dict(postprocessor_overrides or {})

    is_n1_7 = config.model_version == GROOT_N1_7
    pack_inputs_key = "groot_n1_7_pack_inputs_v1" if is_n1_7 else "groot_pack_inputs_v3"
    action_unpack_key = "groot_action_unpack_unnormalize_v1"
    # N1.7 never receives the dataset stats through overrides.
    # NOTE(review): presumably so the stats saved with the processor are kept - confirm.
    inject_stats = dataset_stats is not None and not is_n1_7

    pack_input_overrides = dict(preprocessor_overrides.get(pack_inputs_key, {}))
    # Forced on because the serialized value of an older config may be stale or missing.
    pack_input_overrides["normalize_min_max"] = True
    if inject_stats:
        pack_input_overrides["stats"] = dataset_stats
    preprocessor_overrides[pack_inputs_key] = pack_input_overrides

    try:
        env_action_dim = int(config.output_features[ACTION].shape[0])
    except (AttributeError, KeyError, IndexError, TypeError, ValueError):
        # Best effort: fall back to 0 when the config carries no usable action
        # feature (missing entry, missing/empty shape, non-integer dimension).
        # Only these lookup/conversion failures are swallowed; anything else is
        # unexpected and is allowed to propagate.
        env_action_dim = 0

    action_unpack_overrides = dict(postprocessor_overrides.get(action_unpack_key, {}))
    action_unpack_overrides["normalize_min_max"] = True
    action_unpack_overrides["env_action_dim"] = env_action_dim
    if inject_stats:
        action_unpack_overrides["stats"] = dataset_stats
    postprocessor_overrides[action_unpack_key] = action_unpack_overrides

    return preprocessor_overrides, postprocessor_overrides
|
||||
|
||||
|
||||
def make_groot_pre_post_processors_from_pretrained(
    config: GrootConfig,
    pretrained_path: str,
    *,
    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
    preprocessor_overrides: dict[str, Any] | None = None,
    postprocessor_overrides: dict[str, Any] | None = None,
    preprocessor_config_filename: str = f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json",
    postprocessor_config_filename: str = f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json",
) -> tuple[
    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
    PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
    """Load Groot processors while preserving compatibility with older serialized configs.

    Two paths are supported:

    * When ``is_raw_groot_n1_7_checkpoint(pretrained_path)`` is true, the
      processors are built by ``make_groot_pre_post_processors`` from a shallow
      copy of ``config`` whose ``base_model_path`` is set to ``pretrained_path``.
      The override and filename arguments are not used on this path.
    * Otherwise both pipelines are loaded with
      ``PolicyProcessorPipeline.from_pretrained`` using the overrides returned
      by ``_legacy_groot_processor_overrides``.

    Args:
        config: Groot policy configuration; it is not modified.
        pretrained_path: Location of the checkpoint / serialized processors.
        dataset_stats: Optional dataset statistics forwarded to the builder or
            to the legacy override helper.
        preprocessor_overrides: Optional user overrides for the preprocessor.
        postprocessor_overrides: Optional user overrides for the postprocessor.
        preprocessor_config_filename: Preprocessor config file name to load.
        postprocessor_config_filename: Postprocessor config file name to load.

    Returns:
        The ``(preprocessor, postprocessor)`` pipelines.
    """
    if is_raw_groot_n1_7_checkpoint(pretrained_path):
        # Work on a copy so the caller's config keeps its own base_model_path.
        raw_checkpoint_cfg = copy(config)
        raw_checkpoint_cfg.base_model_path = str(pretrained_path)
        return make_groot_pre_post_processors(config=raw_checkpoint_cfg, dataset_stats=dataset_stats)

    pre_overrides, post_overrides = _legacy_groot_processor_overrides(
        config=config,
        dataset_stats=dataset_stats,
        preprocessor_overrides=preprocessor_overrides,
        postprocessor_overrides=postprocessor_overrides,
    )

    def _load_pipeline(config_filename, overrides, to_transition, to_output):
        # Both pipelines share the same source path and differ only in these four arguments.
        return PolicyProcessorPipeline.from_pretrained(
            pretrained_model_name_or_path=pretrained_path,
            config_filename=config_filename,
            overrides=overrides,
            to_transition=to_transition,
            to_output=to_output,
        )

    # Tuple elements are evaluated left to right: preprocessor first, then postprocessor.
    return (
        _load_pipeline(
            preprocessor_config_filename,
            pre_overrides,
            batch_to_transition,
            transition_to_batch,
        ),
        _load_pipeline(
            postprocessor_config_filename,
            post_overrides,
            policy_action_to_transition,
            transition_to_policy_action,
        ),
    )
|
||||
|
||||
|
||||
def make_groot_pre_post_processors(
|
||||
config: GrootConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
|
||||
) -> tuple[
|
||||
|
||||
@@ -1475,6 +1475,69 @@ def test_groot_n1_7_saved_processors_reload_through_factory_preserves_saved_stat
|
||||
assert unpack_step.env_action_dim == 7
|
||||
|
||||
|
||||
def test_groot_legacy_n1_5_processors_reload_with_compatibility_overrides(tmp_path):
    """Legacy N1.5 processor configs reload through the factory with patched fields.

    Writes minimal pre/postprocessor JSON files whose steps have
    ``normalize_min_max`` disabled, ``env_action_dim`` set to 0 and no stats,
    then checks that loading them via ``make_pre_post_processors`` yields steps
    with min/max normalization enabled, an action dimension of 7 and the
    provided dataset stats.
    """
    config = _groot_config(GROOT_N1_5)
    dataset_stats = {
        OBS_STATE: {"min": torch.full((8,), -1.0), "max": torch.full((8,), 1.0)},
        ACTION: {"min": torch.full((7,), -2.0), "max": torch.full((7,), 2.0)},
    }

    pack_step_config = {
        "state_horizon": 1,
        "action_horizon": 16,
        "max_state_dim": config.max_state_dim,
        "max_action_dim": config.max_action_dim,
        "language_key": "task",
        "formalize_language": False,
        "embodiment_tag": config.embodiment_tag,
        "embodiment_mapping": {"new_embodiment": 31},
        "normalize_min_max": False,
    }
    unpack_step_config = {
        "env_action_dim": 0,
        "normalize_min_max": False,
    }

    # (pipeline name, step registry name, step config); written preprocessor first.
    legacy_pipelines = (
        ("policy_preprocessor", "groot_pack_inputs_v3", pack_step_config),
        ("policy_postprocessor", "groot_action_unpack_unnormalize_v1", unpack_step_config),
    )
    for pipeline_name, registry_name, step_config in legacy_pipelines:
        serialized = {
            "name": pipeline_name,
            "steps": [{"registry_name": registry_name, "config": step_config}],
        }
        (tmp_path / f"{pipeline_name}.json").write_text(json.dumps(serialized))

    loaded_preprocessor, loaded_postprocessor = make_pre_post_processors(
        config,
        pretrained_path=str(tmp_path),
        dataset_stats=dataset_stats,
    )

    pack_step = loaded_preprocessor.steps[0]
    unpack_step = loaded_postprocessor.steps[0]

    assert pack_step.normalize_min_max
    assert unpack_step.normalize_min_max
    assert unpack_step.env_action_dim == 7

    # Both steps must carry the dataset stats supplied at load time.
    expected_state_min = dataset_stats[OBS_STATE]["min"]
    expected_action_max = dataset_stats[ACTION]["max"]
    for loaded_step in (pack_step, unpack_step):
        torch.testing.assert_close(loaded_step.stats[OBS_STATE]["min"], expected_state_min)
        torch.testing.assert_close(loaded_step.stats[ACTION]["max"], expected_action_max)
|
||||
|
||||
|
||||
def test_groot_policy_selects_n1_7_model_class(monkeypatch):
|
||||
from lerobot.policies.groot.groot_n1_7 import GR00TN17
|
||||
|
||||
|
||||
Reference in New Issue
Block a user