removed remaining N1.5 traces

This commit is contained in:
nv-sachdevkartik
2026-06-05 00:11:37 +00:00
parent a35ac22afd
commit 90d1e70da2
2 changed files with 34 additions and 18 deletions
@@ -24,6 +24,11 @@ from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
from lerobot.utils.constants import ACTION, OBS_STATE
GROOT_N1_7 = "n1.7"
# Legacy GR00T N1.5 identifier. N1.5 is NOT a supported model_version (it is
# intentionally absent from _GROOT_MODEL_VERSION_ALIASES so normalize_groot_model_version
# still rejects it). It is retained only so that infer_groot_model_version can recognise
# an N1.5 base path/checkpoint and the N1.7 config/loader can reject the mismatch.
GROOT_N1_5 = "n1.5"
GROOT_N1_7_BASE_MODEL = "nvidia/GR00T-N1.7-3B"
GROOT_N1_7_BACKBONE_MODEL = "nvidia/Cosmos-Reason2-2B"
GROOT_ACTION_DECODE_TRANSFORM_LIBERO = "libero"
@@ -74,6 +79,11 @@ def infer_groot_model_version(model_path: str | None) -> str | None:
model_path_lower = model_path.lower()
if "gr00t-n1.7" in model_path_lower or "gr00t_n1.7" in model_path_lower:
return GROOT_N1_7
# Detect legacy N1.5 paths so the N1.7 config/loader can reject the mismatch.
# N1.5 is unsupported, but it must still be recognised here to fail loudly
# rather than silently treating an N1.5 checkpoint as N1.7.
if "gr00t-n1.5" in model_path_lower or "gr00t_n1.5" in model_path_lower:
return GROOT_N1_5
config_version = _infer_groot_model_version_from_local_config(model_path)
if config_version is not None:
return config_version
+24 -18
View File
@@ -506,17 +506,20 @@ def test_groot_from_pretrained_rejects_mismatched_caller_config(tmp_path):
model_path = tmp_path / "GR00T-N1.7-local"
model_path.mkdir()
input_features, output_features = _groot_features(state_dim=8, action_dim=7)
config = GrootConfig(
model_version=GROOT_N1_7,
base_model_path="nvidia/GR00T-N1.5-3B",
input_features=input_features,
output_features=output_features,
device="cpu",
use_bf16=False,
action_decode_transform=GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
)
# An N1.7 config paired with a legacy N1.5 base path is a mismatch and must be
# rejected. The mismatch is detected during config validation (__post_init__),
# so construction itself raises before from_pretrained is reached.
with pytest.raises(ValueError, match="does not match base_model_path"):
config = GrootConfig(
model_version=GROOT_N1_7,
base_model_path="nvidia/GR00T-N1.5-3B",
input_features=input_features,
output_features=output_features,
device="cpu",
use_bf16=False,
action_decode_transform=GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
)
GrootPolicy.from_pretrained(model_path, config=config)
@@ -1238,17 +1241,20 @@ def test_groot_from_pretrained_rejects_caller_config_mismatch_from_local_config(
model_path.mkdir()
(model_path / "config.json").write_text('{"model_type": "Gr00tN1d7"}')
input_features, output_features = _groot_features(state_dim=8, action_dim=7)
config = GrootConfig(
model_version=GROOT_N1_7,
base_model_path="nvidia/GR00T-N1.5-3B",
input_features=input_features,
output_features=output_features,
device="cpu",
use_bf16=False,
action_decode_transform=GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
)
# An N1.7 config paired with a legacy N1.5 base path is a mismatch and must be
# rejected. The mismatch is detected during config validation (__post_init__),
# so construction itself raises before from_pretrained is reached.
with pytest.raises(ValueError, match="does not match base_model_path"):
config = GrootConfig(
model_version=GROOT_N1_7,
base_model_path="nvidia/GR00T-N1.5-3B",
input_features=input_features,
output_features=output_features,
device="cpu",
use_bf16=False,
action_decode_transform=GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
)
GrootPolicy.from_pretrained(model_path, config=config)