From 90d1e70da2abf1141a8185771d963534657fff3c Mon Sep 17 00:00:00 2001 From: nv-sachdevkartik Date: Fri, 5 Jun 2026 00:11:37 +0000 Subject: [PATCH] removed remaining N1.5 traces --- .../policies/groot/configuration_groot.py | 10 +++++ tests/policies/groot/test_groot_n1_7.py | 42 +++++++++++-------- 2 files changed, 34 insertions(+), 18 deletions(-) diff --git a/src/lerobot/policies/groot/configuration_groot.py b/src/lerobot/policies/groot/configuration_groot.py index ed93a8c0b..02f912a45 100644 --- a/src/lerobot/policies/groot/configuration_groot.py +++ b/src/lerobot/policies/groot/configuration_groot.py @@ -24,6 +24,11 @@ from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig from lerobot.utils.constants import ACTION, OBS_STATE GROOT_N1_7 = "n1.7" +# Legacy GR00T N1.5 identifier. N1.5 is NOT a supported model_version (it is +# intentionally absent from _GROOT_MODEL_VERSION_ALIASES so normalize_groot_model_version +# still rejects it). It is retained only so that infer_groot_model_version can recognise +# an N1.5 base path/checkpoint and the N1.7 config/loader can reject the mismatch. +GROOT_N1_5 = "n1.5" GROOT_N1_7_BASE_MODEL = "nvidia/GR00T-N1.7-3B" GROOT_N1_7_BACKBONE_MODEL = "nvidia/Cosmos-Reason2-2B" GROOT_ACTION_DECODE_TRANSFORM_LIBERO = "libero" @@ -74,6 +79,11 @@ def infer_groot_model_version(model_path: str | None) -> str | None: model_path_lower = model_path.lower() if "gr00t-n1.7" in model_path_lower or "gr00t_n1.7" in model_path_lower: return GROOT_N1_7 + # Detect legacy N1.5 paths so the N1.7 config/loader can reject the mismatch. + # N1.5 is unsupported, but it must still be recognised here to fail loudly + # rather than silently treating an N1.5 checkpoint as N1.7. + if "gr00t-n1.5" in model_path_lower or "gr00t_n1.5" in model_path_lower: + return GROOT_N1_5 config_version = _infer_groot_model_version_from_local_config(model_path) if config_version is not None: return config_version diff --git a/tests/policies/groot/test_groot_n1_7.py b/tests/policies/groot/test_groot_n1_7.py index c2422433b..918162e18 100644 --- a/tests/policies/groot/test_groot_n1_7.py +++ b/tests/policies/groot/test_groot_n1_7.py @@ -506,17 +506,20 @@ def test_groot_from_pretrained_rejects_mismatched_caller_config(tmp_path): model_path = tmp_path / "GR00T-N1.7-local" model_path.mkdir() input_features, output_features = _groot_features(state_dim=8, action_dim=7) - config = GrootConfig( - model_version=GROOT_N1_7, - base_model_path="nvidia/GR00T-N1.5-3B", - input_features=input_features, - output_features=output_features, - device="cpu", - use_bf16=False, - action_decode_transform=GROOT_ACTION_DECODE_TRANSFORM_LIBERO, - ) + # An N1.7 config paired with a legacy N1.5 base path is a mismatch and must be + # rejected. The mismatch is detected during config validation (__post_init__), + # so construction itself raises before from_pretrained is reached. with pytest.raises(ValueError, match="does not match base_model_path"): + config = GrootConfig( + model_version=GROOT_N1_7, + base_model_path="nvidia/GR00T-N1.5-3B", + input_features=input_features, + output_features=output_features, + device="cpu", + use_bf16=False, + action_decode_transform=GROOT_ACTION_DECODE_TRANSFORM_LIBERO, + ) GrootPolicy.from_pretrained(model_path, config=config) @@ -1238,17 +1241,20 @@ def test_groot_from_pretrained_rejects_caller_config_mismatch_from_local_config( model_path.mkdir() (model_path / "config.json").write_text('{"model_type": "Gr00tN1d7"}') input_features, output_features = _groot_features(state_dim=8, action_dim=7) - config = GrootConfig( - model_version=GROOT_N1_7, - base_model_path="nvidia/GR00T-N1.5-3B", - input_features=input_features, - output_features=output_features, - device="cpu", - use_bf16=False, - action_decode_transform=GROOT_ACTION_DECODE_TRANSFORM_LIBERO, - ) + # An N1.7 config paired with a legacy N1.5 base path is a mismatch and must be + # rejected. The mismatch is detected during config validation (__post_init__), + # so construction itself raises before from_pretrained is reached. with pytest.raises(ValueError, match="does not match base_model_path"): + config = GrootConfig( + model_version=GROOT_N1_7, + base_model_path="nvidia/GR00T-N1.5-3B", + input_features=input_features, + output_features=output_features, + device="cpu", + use_bf16=False, + action_decode_transform=GROOT_ACTION_DECODE_TRANSFORM_LIBERO, + ) GrootPolicy.from_pretrained(model_path, config=config)