Move Groot processor compatibility into Groot loader

This commit is contained in:
Andrew Wrenn
2026-06-02 13:19:12 -07:00
committed by Andy Wrenn
parent 9c26e111d1
commit 111dceeb8a
3 changed files with 160 additions and 42 deletions
+63
View File
@@ -1475,6 +1475,69 @@ def test_groot_n1_7_saved_processors_reload_through_factory_preserves_saved_stat
assert unpack_step.env_action_dim == 7
def test_groot_legacy_n1_5_processors_reload_with_compatibility_overrides(tmp_path):
config = _groot_config(GROOT_N1_5)
dataset_stats = {
OBS_STATE: {
"min": torch.full((8,), -1.0),
"max": torch.full((8,), 1.0),
},
ACTION: {
"min": torch.full((7,), -2.0),
"max": torch.full((7,), 2.0),
},
}
legacy_preprocessor_config = {
"name": "policy_preprocessor",
"steps": [
{
"registry_name": "groot_pack_inputs_v3",
"config": {
"state_horizon": 1,
"action_horizon": 16,
"max_state_dim": config.max_state_dim,
"max_action_dim": config.max_action_dim,
"language_key": "task",
"formalize_language": False,
"embodiment_tag": config.embodiment_tag,
"embodiment_mapping": {"new_embodiment": 31},
"normalize_min_max": False,
},
}
],
}
legacy_postprocessor_config = {
"name": "policy_postprocessor",
"steps": [
{
"registry_name": "groot_action_unpack_unnormalize_v1",
"config": {
"env_action_dim": 0,
"normalize_min_max": False,
},
}
],
}
(tmp_path / "policy_preprocessor.json").write_text(json.dumps(legacy_preprocessor_config))
(tmp_path / "policy_postprocessor.json").write_text(json.dumps(legacy_postprocessor_config))
loaded_preprocessor, loaded_postprocessor = make_pre_post_processors(
config,
pretrained_path=str(tmp_path),
dataset_stats=dataset_stats,
)
pack_step = loaded_preprocessor.steps[0]
unpack_step = loaded_postprocessor.steps[0]
assert pack_step.normalize_min_max
assert unpack_step.normalize_min_max
assert unpack_step.env_action_dim == 7
torch.testing.assert_close(pack_step.stats[OBS_STATE]["min"], dataset_stats[OBS_STATE]["min"])
torch.testing.assert_close(pack_step.stats[ACTION]["max"], dataset_stats[ACTION]["max"])
torch.testing.assert_close(unpack_step.stats[OBS_STATE]["min"], dataset_stats[OBS_STATE]["min"])
torch.testing.assert_close(unpack_step.stats[ACTION]["max"], dataset_stats[ACTION]["max"])
def test_groot_policy_selects_n1_7_model_class(monkeypatch):
from lerobot.policies.groot.groot_n1_7 import GR00TN17