mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-04 00:27:15 +00:00
Move Groot processor compatibility into Groot loader
This commit is contained in:
@@ -1475,6 +1475,69 @@ def test_groot_n1_7_saved_processors_reload_through_factory_preserves_saved_stat
|
||||
assert unpack_step.env_action_dim == 7
|
||||
|
||||
|
||||
def test_groot_legacy_n1_5_processors_reload_with_compatibility_overrides(tmp_path):
|
||||
config = _groot_config(GROOT_N1_5)
|
||||
dataset_stats = {
|
||||
OBS_STATE: {
|
||||
"min": torch.full((8,), -1.0),
|
||||
"max": torch.full((8,), 1.0),
|
||||
},
|
||||
ACTION: {
|
||||
"min": torch.full((7,), -2.0),
|
||||
"max": torch.full((7,), 2.0),
|
||||
},
|
||||
}
|
||||
legacy_preprocessor_config = {
|
||||
"name": "policy_preprocessor",
|
||||
"steps": [
|
||||
{
|
||||
"registry_name": "groot_pack_inputs_v3",
|
||||
"config": {
|
||||
"state_horizon": 1,
|
||||
"action_horizon": 16,
|
||||
"max_state_dim": config.max_state_dim,
|
||||
"max_action_dim": config.max_action_dim,
|
||||
"language_key": "task",
|
||||
"formalize_language": False,
|
||||
"embodiment_tag": config.embodiment_tag,
|
||||
"embodiment_mapping": {"new_embodiment": 31},
|
||||
"normalize_min_max": False,
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
legacy_postprocessor_config = {
|
||||
"name": "policy_postprocessor",
|
||||
"steps": [
|
||||
{
|
||||
"registry_name": "groot_action_unpack_unnormalize_v1",
|
||||
"config": {
|
||||
"env_action_dim": 0,
|
||||
"normalize_min_max": False,
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
(tmp_path / "policy_preprocessor.json").write_text(json.dumps(legacy_preprocessor_config))
|
||||
(tmp_path / "policy_postprocessor.json").write_text(json.dumps(legacy_postprocessor_config))
|
||||
|
||||
loaded_preprocessor, loaded_postprocessor = make_pre_post_processors(
|
||||
config,
|
||||
pretrained_path=str(tmp_path),
|
||||
dataset_stats=dataset_stats,
|
||||
)
|
||||
|
||||
pack_step = loaded_preprocessor.steps[0]
|
||||
unpack_step = loaded_postprocessor.steps[0]
|
||||
assert pack_step.normalize_min_max
|
||||
assert unpack_step.normalize_min_max
|
||||
assert unpack_step.env_action_dim == 7
|
||||
torch.testing.assert_close(pack_step.stats[OBS_STATE]["min"], dataset_stats[OBS_STATE]["min"])
|
||||
torch.testing.assert_close(pack_step.stats[ACTION]["max"], dataset_stats[ACTION]["max"])
|
||||
torch.testing.assert_close(unpack_step.stats[OBS_STATE]["min"], dataset_stats[OBS_STATE]["min"])
|
||||
torch.testing.assert_close(unpack_step.stats[ACTION]["max"], dataset_stats[ACTION]["max"])
|
||||
|
||||
|
||||
def test_groot_policy_selects_n1_7_model_class(monkeypatch):
|
||||
from lerobot.policies.groot.groot_n1_7 import GR00TN17
|
||||
|
||||
|
||||
Reference in New Issue
Block a user