mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 00:37:10 +00:00
removed remaining N1.5 traces
This commit is contained in:
@@ -24,6 +24,11 @@ from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
|
||||
GROOT_N1_7 = "n1.7"
|
||||
# Legacy GR00T N1.5 identifier. N1.5 is NOT a supported model_version (it is
|
||||
# intentionally absent from _GROOT_MODEL_VERSION_ALIASES so normalize_groot_model_version
|
||||
# still rejects it). It is retained only so that infer_groot_model_version can recognise
|
||||
# an N1.5 base path/checkpoint and the N1.7 config/loader can reject the mismatch.
|
||||
GROOT_N1_5 = "n1.5"
|
||||
GROOT_N1_7_BASE_MODEL = "nvidia/GR00T-N1.7-3B"
|
||||
GROOT_N1_7_BACKBONE_MODEL = "nvidia/Cosmos-Reason2-2B"
|
||||
GROOT_ACTION_DECODE_TRANSFORM_LIBERO = "libero"
|
||||
@@ -74,6 +79,11 @@ def infer_groot_model_version(model_path: str | None) -> str | None:
|
||||
model_path_lower = model_path.lower()
|
||||
if "gr00t-n1.7" in model_path_lower or "gr00t_n1.7" in model_path_lower:
|
||||
return GROOT_N1_7
|
||||
# Detect legacy N1.5 paths so the N1.7 config/loader can reject the mismatch.
|
||||
# N1.5 is unsupported, but it must still be recognised here to fail loudly
|
||||
# rather than silently treating an N1.5 checkpoint as N1.7.
|
||||
if "gr00t-n1.5" in model_path_lower or "gr00t_n1.5" in model_path_lower:
|
||||
return GROOT_N1_5
|
||||
config_version = _infer_groot_model_version_from_local_config(model_path)
|
||||
if config_version is not None:
|
||||
return config_version
|
||||
|
||||
@@ -506,17 +506,20 @@ def test_groot_from_pretrained_rejects_mismatched_caller_config(tmp_path):
|
||||
model_path = tmp_path / "GR00T-N1.7-local"
|
||||
model_path.mkdir()
|
||||
input_features, output_features = _groot_features(state_dim=8, action_dim=7)
|
||||
config = GrootConfig(
|
||||
model_version=GROOT_N1_7,
|
||||
base_model_path="nvidia/GR00T-N1.5-3B",
|
||||
input_features=input_features,
|
||||
output_features=output_features,
|
||||
device="cpu",
|
||||
use_bf16=False,
|
||||
action_decode_transform=GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
|
||||
)
|
||||
|
||||
# An N1.7 config paired with a legacy N1.5 base path is a mismatch and must be
|
||||
# rejected. The mismatch is detected during config validation (__post_init__),
|
||||
# so construction itself raises before from_pretrained is reached.
|
||||
with pytest.raises(ValueError, match="does not match base_model_path"):
|
||||
config = GrootConfig(
|
||||
model_version=GROOT_N1_7,
|
||||
base_model_path="nvidia/GR00T-N1.5-3B",
|
||||
input_features=input_features,
|
||||
output_features=output_features,
|
||||
device="cpu",
|
||||
use_bf16=False,
|
||||
action_decode_transform=GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
|
||||
)
|
||||
GrootPolicy.from_pretrained(model_path, config=config)
|
||||
|
||||
|
||||
@@ -1238,17 +1241,20 @@ def test_groot_from_pretrained_rejects_caller_config_mismatch_from_local_config(
|
||||
model_path.mkdir()
|
||||
(model_path / "config.json").write_text('{"model_type": "Gr00tN1d7"}')
|
||||
input_features, output_features = _groot_features(state_dim=8, action_dim=7)
|
||||
config = GrootConfig(
|
||||
model_version=GROOT_N1_7,
|
||||
base_model_path="nvidia/GR00T-N1.5-3B",
|
||||
input_features=input_features,
|
||||
output_features=output_features,
|
||||
device="cpu",
|
||||
use_bf16=False,
|
||||
action_decode_transform=GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
|
||||
)
|
||||
|
||||
# An N1.7 config paired with a legacy N1.5 base path is a mismatch and must be
|
||||
# rejected. The mismatch is detected during config validation (__post_init__),
|
||||
# so construction itself raises before from_pretrained is reached.
|
||||
with pytest.raises(ValueError, match="does not match base_model_path"):
|
||||
config = GrootConfig(
|
||||
model_version=GROOT_N1_7,
|
||||
base_model_path="nvidia/GR00T-N1.5-3B",
|
||||
input_features=input_features,
|
||||
output_features=output_features,
|
||||
device="cpu",
|
||||
use_bf16=False,
|
||||
action_decode_transform=GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
|
||||
)
|
||||
GrootPolicy.from_pretrained(model_path, config=config)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user