From bf9877fa0ba82d2b60bbb9873e2076fe6a35800f Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Fri, 12 Jun 2026 23:38:08 +0200 Subject: [PATCH] test(groot): regression coverage and CI guards for the N1.7 review fixes Adds/updates unit tests for the N1.5 removal surfaces (config, checkpoint markers, removed processor steps, v2 action-unpack registration), the legacy-default remap warnings, action_decode_transform auto/none resolution, 2-D action decoding, the per-instance raw-state cache and pack/decode reconnection, raw-checkpoint stats fallback and override handling, camera-match warnings, bf16 handling, and backbone loading kwargs. Adds pytest.importorskip guards so the fast_tests tiers pass without transformers, and asserts the training forward pass returns a finite loss. Note: these tests exercise symbols introduced by the GR00T N1.7 source PRs (source-core, backbone); merge those for green CI on this branch. Co-Authored-By: Claude Opus 4.8 (1M context) --- tests/policies/groot/test_groot_lerobot.py | 5 + tests/policies/groot/test_groot_n1_7.py | 966 ++++++++++++++++++++- 2 files changed, 928 insertions(+), 43 deletions(-) diff --git a/tests/policies/groot/test_groot_lerobot.py b/tests/policies/groot/test_groot_lerobot.py index 34acdef2f..3ddd6bf33 100644 --- a/tests/policies/groot/test_groot_lerobot.py +++ b/tests/policies/groot/test_groot_lerobot.py @@ -207,6 +207,11 @@ def test_lerobot_groot_forward_pass(): with torch.no_grad(): lerobot_loss, lerobot_metrics = lerobot_policy.forward(batch_lerobot_processed) + assert isinstance(lerobot_loss, torch.Tensor) + assert torch.isfinite(lerobot_loss).all() + assert "loss" in lerobot_metrics + assert np.isfinite(lerobot_metrics["loss"]) + print("\nForward pass successful.") print(f" - Loss: {lerobot_loss.item():.6f}") print(f" - Metrics: {lerobot_metrics}") diff --git a/tests/policies/groot/test_groot_n1_7.py b/tests/policies/groot/test_groot_n1_7.py index 5f06bb73e..4d32eb841 100644 --- a/tests/policies/groot/test_groot_n1_7.py +++ b/tests/policies/groot/test_groot_n1_7.py @@ -16,6 +16,8 @@ import inspect import json +import logging +import re import sys from types import SimpleNamespace from unittest.mock import patch @@ -23,15 +25,19 @@ from unittest.mock import patch import numpy as np import pytest import torch +from draccus.utils import ParsingError from torch import nn -from lerobot.configs import FeatureType, PolicyFeature +from lerobot.configs import FeatureType, PolicyFeature, PreTrainedConfig from lerobot.policies.factory import make_policy_config, make_pre_post_processors from lerobot.policies.groot.configuration_groot import ( GROOT_ACTION_DECODE_TRANSFORM_LIBERO, + GROOT_N1_5, + GROOT_N1_5_REMOVAL_GUIDANCE, GROOT_N1_7, GROOT_N1_7_BASE_MODEL, GrootConfig, + infer_groot_model_version, infer_groot_n1_7_action_execution_horizon, infer_groot_n1_7_action_horizon, ) @@ -47,7 +53,9 @@ from lerobot.policies.groot.processor_groot import ( from lerobot.processor import ( AbsoluteActionsProcessorStep, PolicyProcessorPipeline, + ProcessorStepRegistry, RelativeActionsProcessorStep, + RenameObservationsProcessorStep, ) from lerobot.types import TransitionKey from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE @@ -93,6 +101,7 @@ def _raw_n1_7_libero_config(model_path) -> GrootConfig: def test_n1_7_backbone_accepts_transformers_5_layout_and_forwards_mm_token_type_ids(monkeypatch): + pytest.importorskip("transformers") from transformers.feature_extraction_utils import BatchFeature import lerobot.policies.groot.groot_n1_7 as groot_n1_7 @@ -330,8 +339,9 @@ class _DummyGrootModel(nn.Module): def __init__(self): super().__init__() self.weight = nn.Parameter(torch.zeros(())) - self.config = SimpleNamespace(compute_dtype="float32") - self.compute_dtype = "float32" + # Like the real GR00TN17, the dummy defines no compute_dtype attribute: + # GrootPolicy only sets it when use_bf16 is enabled. + self.config = SimpleNamespace() self.forward_inputs = None self.get_action_options = None @@ -378,20 +388,31 @@ def test_groot_n1_7_rejects_legacy_libero_gripper_action_decode_transform(legacy ) -@pytest.mark.parametrize("legacy_version", ["n1.5", "n1_5", "n15", "1.5"]) -def test_groot_rejects_n1_5_aliases(legacy_version): - with pytest.raises(ValueError, match="Unsupported GR00T model_version"): +@pytest.mark.parametrize("legacy_version", ["n1.5", "n1_5", "n1d5", "n15", "1.5"]) +def test_groot_rejects_n1_5_aliases_with_removal_guidance(legacy_version): + with pytest.raises(ValueError, match="Unsupported GR00T model_version") as exc_info: GrootConfig(model_version=legacy_version, device="cpu") + assert GROOT_N1_5_REMOVAL_GUIDANCE in str(exc_info.value) + + +def test_groot_rejected_non_n1_5_version_omits_removal_guidance(): + with pytest.raises(ValueError, match="Unsupported GR00T model_version") as exc_info: + GrootConfig(model_version="n2.0", device="cpu") + + assert GROOT_N1_5_REMOVAL_GUIDANCE not in str(exc_info.value) + def test_groot_config_rejects_mismatched_n1_5_path_for_n1_7(): - with pytest.raises(ValueError, match="does not match base_model_path"): + with pytest.raises(ValueError, match="does not match base_model_path") as exc_info: GrootConfig( model_version=GROOT_N1_7, base_model_path="nvidia/GR00T-N1.5-3B", device="cpu", ) + assert GROOT_N1_5_REMOVAL_GUIDANCE in str(exc_info.value) + def test_groot_n1_7_can_be_selected_from_policy_config_factory_without_external_gr00t(): sys.modules.pop("gr00t", None) @@ -411,6 +432,7 @@ def test_groot_predict_action_chunk_accepts_rtc_kwargs(): def test_groot_predict_action_chunk_forwards_n1_7_rtc_prefix(monkeypatch): + pytest.importorskip("transformers") from lerobot.policies.groot.groot_n1_7 import GR00TN17 dummy_model = _DummyGrootModel() @@ -440,6 +462,7 @@ def test_groot_predict_action_chunk_forwards_n1_7_rtc_prefix(monkeypatch): def test_groot_predict_action_chunk_strips_padded_n1_7_rtc_prefix(monkeypatch): + pytest.importorskip("transformers") from lerobot.policies.groot.groot_n1_7 import GR00TN17 dummy_model = _DummyGrootModel() @@ -473,6 +496,7 @@ def test_groot_predict_action_chunk_strips_padded_n1_7_rtc_prefix(monkeypatch): def test_groot_n1_7_predict_action_chunk_truncates_to_checkpoint_valid_horizon(tmp_path, monkeypatch): + pytest.importorskip("transformers") from lerobot.policies.groot.groot_n1_7 import GR00TN17 model_path = tmp_path / "libero_spatial" @@ -506,28 +530,29 @@ def test_groot_n1_7_predict_action_chunk_truncates_to_checkpoint_valid_horizon(t torch.testing.assert_close(actions[0, :, 0], torch.arange(16, dtype=torch.float32)) -def test_groot_from_pretrained_rejects_mismatched_caller_config(tmp_path): - model_path = tmp_path / "GR00T-N1.7-local" +def _write_n1_5_marked_checkpoint(model_path): + """Write a generically named local dir whose config.json carries N1.5 content markers.""" model_path.mkdir() - input_features, output_features = _groot_features(state_dim=8, action_dim=7) + (model_path / "config.json").write_text( + json.dumps({"model_type": "gr00t_n1_5", "architectures": ["GR00T_N1_5"]}) + ) - # An N1.7 config paired with a legacy N1.5 base path is a mismatch and must be - # rejected. The mismatch is detected during config validation (__post_init__), - # so construction itself raises before from_pretrained is reached. + +def test_groot_from_pretrained_rejects_n1_5_checkpoint_against_n1_7_caller_config(tmp_path): + model_path = tmp_path / "local-checkpoint" + _write_n1_5_marked_checkpoint(model_path) + config = _groot_config(GROOT_N1_7) + + # The caller config is valid on its own; from_pretrained overrides its + # base_model_path with the pretrained path, detects the N1.5 checkpoint from + # the local config.json content, and must reject the mismatch before any + # model weights are loaded. with pytest.raises(ValueError, match="does not match base_model_path"): - config = GrootConfig( - model_version=GROOT_N1_7, - base_model_path="nvidia/GR00T-N1.5-3B", - input_features=input_features, - output_features=output_features, - device="cpu", - use_bf16=False, - action_decode_transform=GROOT_ACTION_DECODE_TRANSFORM_LIBERO, - ) GrootPolicy.from_pretrained(model_path, config=config) def test_groot_from_pretrained_keeps_matching_caller_config(tmp_path, monkeypatch): + pytest.importorskip("transformers") from lerobot.policies.groot.groot_n1_7 import GR00TN17 model_path = tmp_path / "GR00T-N1.7-local" @@ -543,6 +568,7 @@ def test_groot_from_pretrained_keeps_matching_caller_config(tmp_path, monkeypatc def test_groot_from_pretrained_infers_n1_7_from_ambiguous_local_config(tmp_path, monkeypatch): + pytest.importorskip("transformers") from lerobot.policies.groot.groot_n1_7 import GR00TN17 model_path = tmp_path / "local-checkpoint" @@ -1327,26 +1353,17 @@ def test_groot_n1_7_postprocessor_decodes_action_chunks_without_dropping_timeste torch.testing.assert_close(output[TransitionKey.ACTION][..., -1], torch.tensor([[1.0, -0.0, -1.0]])) -def test_groot_from_pretrained_rejects_caller_config_mismatch_from_local_config(tmp_path): +def test_groot_from_pretrained_rejects_n1_5_checkpoint_without_caller_config(tmp_path): model_path = tmp_path / "local-checkpoint" - model_path.mkdir() - (model_path / "config.json").write_text('{"model_type": "Gr00tN1d7"}') - input_features, output_features = _groot_features(state_dim=8, action_dim=7) + _write_n1_5_marked_checkpoint(model_path) - # An N1.7 config paired with a legacy N1.5 base path is a mismatch and must be - # rejected. The mismatch is detected during config validation (__post_init__), - # so construction itself raises before from_pretrained is reached. - with pytest.raises(ValueError, match="does not match base_model_path"): - config = GrootConfig( - model_version=GROOT_N1_7, - base_model_path="nvidia/GR00T-N1.5-3B", - input_features=input_features, - output_features=output_features, - device="cpu", - use_bf16=False, - action_decode_transform=GROOT_ACTION_DECODE_TRANSFORM_LIBERO, - ) - GrootPolicy.from_pretrained(model_path, config=config) + # Without a caller config, from_pretrained infers the model version from the + # local config.json content ('n1.5') and must fail with the removal guidance + # instead of silently treating the N1.5 checkpoint as N1.7. + with pytest.raises(ValueError, match="Unsupported GR00T model_version") as exc_info: + GrootPolicy.from_pretrained(model_path) + + assert GROOT_N1_5_REMOVAL_GUIDANCE in str(exc_info.value) def test_groot_n1_7_processors_are_registered_lazily_without_external_gr00t(): @@ -1673,7 +1690,7 @@ def test_groot_n1_7_processor_uses_qwen_component_assets(monkeypatch): raise AssertionError("Cosmos does not publish processor_config.json") monkeypatch.setattr(transformers, "AutoTokenizer", FakeTokenizer) - monkeypatch.setattr(transformers, "Qwen2VLImageProcessorFast", FakeImageProcessor) + monkeypatch.setattr(transformers, "Qwen2VLImageProcessor", FakeImageProcessor) monkeypatch.setattr(transformers, "Qwen3VLVideoProcessor", FakeVideoProcessor) monkeypatch.setattr(transformers, "Qwen3VLProcessor", FakeProcessor) @@ -1753,9 +1770,8 @@ def test_groot_n1_7_saved_processors_reload_through_factory_preserves_saved_stat assert unpack_step.env_action_dim == 7 - - def test_groot_policy_selects_n1_7_model_class(monkeypatch): + pytest.importorskip("transformers") from lerobot.policies.groot.groot_n1_7 import GR00TN17 called = {} @@ -1773,6 +1789,7 @@ def test_groot_policy_selects_n1_7_model_class(monkeypatch): def test_groot_policy_forwards_n1_7_qwen_inputs(monkeypatch): + pytest.importorskip("transformers") from lerobot.policies.groot.groot_n1_7 import GR00TN17 dummy_model = _DummyGrootModel() @@ -1823,6 +1840,7 @@ def test_groot_n1_7_libero_execution_horizon_uses_core_eight_action_cadence(tmp_ def test_groot_n1_7_select_action_uses_checkpoint_valid_horizon(tmp_path, monkeypatch): + pytest.importorskip("transformers") from lerobot.policies.groot.groot_n1_7 import GR00TN17 model_path = tmp_path / "libero_spatial" @@ -2016,6 +2034,7 @@ def test_qwen3_backbone_can_initialize_from_config_without_downloading_weights(m def test_gr00t_n1_7_from_pretrained_defers_backbone_weight_loading(monkeypatch, tmp_path): + pytest.importorskip("transformers") from huggingface_hub.errors import HFValidationError import lerobot.policies.groot.groot_n1_7 as groot_n1_7 @@ -2052,6 +2071,8 @@ def test_gr00t_n1_7_from_pretrained_defers_backbone_weight_loading(monkeypatch, def test_gr00t_n1_7_action_head_meta_init_defers_beta_distribution(): pytest.importorskip("diffusers") + # GR00TN17Config subclasses transformers.PretrainedConfig (object fallback otherwise). + pytest.importorskip("transformers") from lerobot.policies.groot.groot_n1_7 import GR00TN17ActionHead, GR00TN17Config @@ -2171,3 +2192,862 @@ def test_gr00t_n1_7_model_forward_with_mocked_backbone(): inference_inputs = {key: value for key, value in inputs.items() if key != "action"} action_output = model.get_action(inference_inputs) assert action_output["action_pred"].shape == (2, config.action_horizon, config.max_action_dim) + + +# --------------------------------------------------------------------------- +# GR00T N1.5 removal: every detection point must fail with the canonical guidance +# --------------------------------------------------------------------------- + + +def test_groot_config_rejects_legacy_n1_5_tokenizer_assets_repo(): + with pytest.raises(ValueError, match="tokenizer_assets_repo") as exc_info: + GrootConfig(tokenizer_assets_repo="nvidia/GR00T-N1.5-3B", device="cpu") + + assert GROOT_N1_5_REMOVAL_GUIDANCE in str(exc_info.value) + + +def test_groot_legacy_n1_5_checkpoint_config_fails_with_removal_guidance(tmp_path): + # config.json layout serialized by lerobot<=0.5.1 groot checkpoints: legacy + # N1.5 defaults plus the N1.5-only 'tokenizer_assets_repo' field. + legacy_config = { + "type": "groot", + "n_obs_steps": 1, + "chunk_size": 50, + "n_action_steps": 50, + "max_state_dim": 64, + "max_action_dim": 32, + "model_version": "n1.5", + "base_model_path": "nvidia/GR00T-N1.5-3B", + "tokenizer_assets_repo": "nvidia/GR00T-N1.5-3B", + "embodiment_tag": "gr1", + "video_backend": "decord", + "output_dir": "./tmp/gr00t", + "device": "cpu", + } + (tmp_path / "config.json").write_text(json.dumps(legacy_config)) + + with pytest.raises(ParsingError) as exc_info: + PreTrainedConfig.from_pretrained(tmp_path) + + # draccus wraps the dataclass error in a generic ParsingError; the clear N1.5 + # removal message must be the root cause instead of an opaque DecodingError + # about unknown config fields. + messages = [] + error: BaseException | None = exc_info.value + while error is not None: + messages.append(str(error)) + error = error.__cause__ or error.__context__ + assert any( + "tokenizer_assets_repo" in message and GROOT_N1_5_REMOVAL_GUIDANCE in message for message in messages + ) + + +@pytest.mark.parametrize( + "config_payload", + [ + {"model_type": "gr00t_n1_5"}, + {"architectures": ["GR00T_N1_5"]}, + {"model_version": "n1_5"}, + {"backbone_cfg": {"eagle_path": "eagle"}}, + ], +) +def test_groot_config_rejects_generic_local_dir_with_n1_5_content_markers(tmp_path, config_payload): + # A renamed local snapshot has no N1.5 hint in its path; the config.json + # content markers (as shipped by nvidia/GR00T-N1.5-3B) must still be detected. + model_path = tmp_path / "renamed-snapshot" + model_path.mkdir() + (model_path / "config.json").write_text(json.dumps(config_payload)) + + assert infer_groot_model_version(str(model_path)) == GROOT_N1_5 + + with pytest.raises(ValueError, match="does not match base_model_path") as exc_info: + GrootConfig(base_model_path=str(model_path), device="cpu") + + assert GROOT_N1_5_REMOVAL_GUIDANCE in str(exc_info.value) + + +@pytest.mark.parametrize( + "registry_name", + [ + "groot_pack_inputs_v3", + "groot_eagle_encode_v3", + "groot_eagle_collate_v3", + "groot_action_unpack_unnormalize_v1", + ], +) +def test_removed_n1_5_processor_steps_fail_with_removal_guidance(tmp_path, registry_name): + (tmp_path / "processor.json").write_text( + json.dumps( + {"name": "legacy_groot_processor", "steps": [{"registry_name": registry_name, "config": {}}]} + ) + ) + + with pytest.raises(ValueError, match=re.escape(GROOT_N1_5_REMOVAL_GUIDANCE)): + PolicyProcessorPipeline.from_pretrained(tmp_path, config_filename="processor.json") + + +def test_groot_action_unpack_step_registers_and_serializes_as_v2(tmp_path): + # The action-chunk semantics changed vs. the N1.5-era v1 step, so the registry + # name was bumped: v1 must never silently load into the new implementation. + assert ProcessorStepRegistry.get("groot_action_unpack_unnormalize_v2") is GrootActionUnpackUnnormalizeStep + + config = _groot_config(GROOT_N1_7) + dataset_stats = { + OBS_STATE: {"min": torch.zeros(8), "max": torch.ones(8)}, + ACTION: {"min": torch.zeros(7), "max": torch.ones(7)}, + } + _, postprocessor = make_groot_pre_post_processors(config, dataset_stats=dataset_stats) + postprocessor.save_pretrained(tmp_path) + + saved = json.loads((tmp_path / "policy_postprocessor.json").read_text()) + assert saved["steps"][0]["registry_name"] == "groot_action_unpack_unnormalize_v2" + + +# --------------------------------------------------------------------------- +# Legacy N1.5-era default remapping warns instead of silently rewriting values +# --------------------------------------------------------------------------- + + +def test_groot_default_config_uses_n1_7_values_without_warnings(caplog): + with caplog.at_level(logging.WARNING, logger="lerobot.policies.groot.configuration_groot"): + config = GrootConfig(device="cpu") + + assert config.max_state_dim == 132 + assert config.max_action_dim == 132 + assert config.chunk_size == 40 + assert config.n_action_steps == 40 + assert tuple(config.image_size) == (256, 256) + assert not any("legacy GR00T N1.5-era default" in record.getMessage() for record in caplog.records) + + +def test_groot_legacy_default_remap_emits_warnings(caplog): + with caplog.at_level(logging.WARNING, logger="lerobot.policies.groot.configuration_groot"): + config = GrootConfig( + chunk_size=50, + n_action_steps=50, + max_state_dim=64, + max_action_dim=32, + image_size=(224, 224), + device="cpu", + ) + + assert config.max_state_dim == 132 + assert config.max_action_dim == 132 + assert config.chunk_size == 40 + assert config.n_action_steps == 40 + assert tuple(config.image_size) == (256, 256) + remap_messages = [ + record.getMessage() + for record in caplog.records + if "legacy GR00T N1.5-era default" in record.getMessage() + ] + assert any("chunk_size=50" in message and "40" in message for message in remap_messages) + assert any("n_action_steps=50" in message and "40" in message for message in remap_messages) + assert any("max_state_dim=64" in message and "132" in message for message in remap_messages) + assert any("max_action_dim=32" in message and "132" in message for message in remap_messages) + assert any("image_size=(224, 224)" in message and "(256, 256)" in message for message in remap_messages) + + +def test_groot_non_legacy_values_are_not_remapped(caplog): + with caplog.at_level(logging.WARNING, logger="lerobot.policies.groot.configuration_groot"): + config = GrootConfig( + chunk_size=48, + n_action_steps=20, + max_state_dim=100, + max_action_dim=65, + image_size=(225, 225), + device="cpu", + ) + + assert config.chunk_size == 48 + assert config.n_action_steps == 20 + assert config.max_state_dim == 100 + assert config.max_action_dim == 65 + assert tuple(config.image_size) == (225, 225) + assert not any("legacy GR00T N1.5-era default" in record.getMessage() for record in caplog.records) + + +# --------------------------------------------------------------------------- +# action_decode_transform: explicit 'none' wins over the embodiment default +# --------------------------------------------------------------------------- + + +def test_groot_explicit_none_action_decode_transform_overrides_libero_default(): + config = GrootConfig(embodiment_tag="libero_sim", action_decode_transform="none", device="cpu") + + assert config.action_decode_transform is None + + +@pytest.mark.parametrize("auto_value", ["auto", "AUTO"]) +def test_groot_auto_action_decode_transform_resolves_to_embodiment_default(auto_value): + libero = GrootConfig(embodiment_tag="libero_sim", action_decode_transform=auto_value, device="cpu") + other = GrootConfig(embodiment_tag="new_embodiment", action_decode_transform=auto_value, device="cpu") + + assert libero.action_decode_transform == GROOT_ACTION_DECODE_TRANSFORM_LIBERO + assert other.action_decode_transform is None + + +def test_groot_action_decode_transform_opt_out_survives_save_load_round_trip(tmp_path): + explicit_none = GrootConfig(embodiment_tag="libero_sim", action_decode_transform="none", device="cpu") + explicit_dir = tmp_path / "explicit_none" + explicit_none.save_pretrained(explicit_dir) + reloaded_none = PreTrainedConfig.from_pretrained(explicit_dir) + + assert reloaded_none.action_decode_transform is None + + unset = GrootConfig(embodiment_tag="libero_sim", device="cpu") + unset_dir = tmp_path / "unset" + unset.save_pretrained(unset_dir) + reloaded_unset = PreTrainedConfig.from_pretrained(unset_dir) + + assert reloaded_unset.action_decode_transform == GROOT_ACTION_DECODE_TRANSFORM_LIBERO + + +# --------------------------------------------------------------------------- +# GrootN17ActionDecodeStep: 2-D (B, D) actions from the sync select_action path +# --------------------------------------------------------------------------- + + +def _symmetric_unit_stats(dim: int) -> dict[str, list[float]]: + return { + "min": [-1.0] * dim, + "max": [1.0] * dim, + "mean": [0.0] * dim, + "std": [1.0] * dim, + "q01": [-1.0] * dim, + "q99": [1.0] * dim, + } + + +def test_groot_n1_7_action_decode_handles_2d_relative_non_eef_actions(): + raw_stats = { + "state": {"single_arm": _symmetric_unit_stats(5)}, + "action": {"single_arm": _symmetric_unit_stats(5)}, + "relative_action": {"single_arm": {"min": [-1.0] * 5, "max": [1.0] * 5}}, + } + modality_config = { + "state": {"modality_keys": ["single_arm"]}, + "action": { + "modality_keys": ["single_arm"], + "action_configs": [ + {"rep": "RELATIVE", "type": "NON_EEF", "format": "DEFAULT", "state_key": None} + ], + }, + } + pack_step = GrootN17PackInputsStep( + raw_stats=raw_stats, modality_config=modality_config, normalize_min_max=False, max_state_dim=8 + ) + reference = torch.tensor([[0.0, 1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0, 9.0]]) + pack_step({TransitionKey.OBSERVATION: {OBS_STATE: reference}, TransitionKey.COMPLEMENTARY_DATA: {}}) + decode_step = GrootN17ActionDecodeStep( + env_action_dim=5, + raw_stats=raw_stats, + modality_config=modality_config, + use_relative_action=True, + pack_step=pack_step, + ) + + output = decode_step({TransitionKey.ACTION: torch.zeros(2, 5)}) + + # Relative stats span [-1, 1], so the normalized 0 decodes to a 0 delta and the + # absolute action equals the cached reference state, preserving the (B, D) rank. + assert output[TransitionKey.ACTION].shape == (2, 5) + torch.testing.assert_close(output[TransitionKey.ACTION], reference) + + # 3-D chunks keep their (B, T, D) rank and decode identically per timestep. + chunk_output = decode_step({TransitionKey.ACTION: torch.zeros(2, 3, 5)}) + assert chunk_output[TransitionKey.ACTION].shape == (2, 3, 5) + torch.testing.assert_close(chunk_output[TransitionKey.ACTION], reference[:, None, :].expand(2, 3, 5)) + + +def test_groot_n1_7_action_decode_handles_2d_mixed_relative_and_absolute_groups(): + raw_stats = { + "state": {"single_arm": {"mean": [0.0] * 5}, "gripper": {"mean": [0.0]}}, + "action": { + "single_arm": _symmetric_unit_stats(5), + "gripper": { + "min": [0.0], + "max": [10.0], + "mean": [5.0], + "std": [1.0], + "q01": [0.0], + "q99": [10.0], + }, + }, + "relative_action": {"single_arm": {"min": [-1.0] * 5, "max": [1.0] * 5}}, + } + modality_config = { + "state": {"modality_keys": ["single_arm", "gripper"]}, + "action": { + "modality_keys": ["single_arm", "gripper"], + "action_configs": [ + {"rep": "RELATIVE", "type": "NON_EEF", "format": "DEFAULT", "state_key": None}, + {"rep": "ABSOLUTE", "type": "NON_EEF", "format": "DEFAULT", "state_key": None}, + ], + }, + } + pack_step = GrootN17PackInputsStep( + raw_stats=raw_stats, modality_config=modality_config, normalize_min_max=False, max_state_dim=8 + ) + pack_step( + { + TransitionKey.OBSERVATION: {OBS_STATE: torch.tensor([[0.0, 1.0, 2.0, 3.0, 4.0, 9.0]])}, + TransitionKey.COMPLEMENTARY_DATA: {}, + } + ) + decode_step = GrootN17ActionDecodeStep( + env_action_dim=6, + raw_stats=raw_stats, + modality_config=modality_config, + use_relative_action=True, + pack_step=pack_step, + ) + + output = decode_step({TransitionKey.ACTION: torch.zeros(1, 6)}) + + # Arm group: 0 delta added to the cached reference state [0..4]. Gripper group is + # absolute: a normalized 0 unnormalizes to the [0, 10] midpoint 5.0 (not the raw + # reference state 9.0). + assert output[TransitionKey.ACTION].shape == (1, 6) + torch.testing.assert_close(output[TransitionKey.ACTION], torch.tensor([[0.0, 1.0, 2.0, 3.0, 4.0, 5.0]])) + + +def test_groot_n1_7_action_decode_handles_2d_relative_eef_xyz_rot6d_actions(): + raw_stats = { + "state": {"eef": {"mean": [0.0] * 9}}, + "action": {"eef": _symmetric_unit_stats(9)}, + "relative_action": {"eef": {"min": [-1.0] * 9, "max": [1.0] * 9}}, + } + modality_config = { + "state": {"modality_keys": ["eef"]}, + "action": { + "modality_keys": ["eef"], + "action_configs": [{"rep": "RELATIVE", "type": "EEF", "format": "XYZ+ROT6D", "state_key": None}], + }, + } + pack_step = GrootN17PackInputsStep( + raw_stats=raw_stats, modality_config=modality_config, normalize_min_max=False, max_state_dim=16 + ) + # Reference pose: translation (1, 2, 3) with the identity rotation in rot6d form. + identity_rot6d = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0] + pack_step( + { + TransitionKey.OBSERVATION: {OBS_STATE: torch.tensor([[1.0, 2.0, 3.0, *identity_rot6d]])}, + TransitionKey.COMPLEMENTARY_DATA: {}, + } + ) + decode_step = GrootN17ActionDecodeStep( + env_action_dim=9, + raw_stats=raw_stats, + modality_config=modality_config, + use_relative_action=True, + pack_step=pack_step, + ) + + # Relative stats span [-1, 1], so the normalized values below ARE the decoded + # deltas: translate +0.5 along x with no rotation change. + output = decode_step({TransitionKey.ACTION: torch.tensor([[0.5, 0.0, 0.0, *identity_rot6d]])}) + + assert output[TransitionKey.ACTION].shape == (1, 9) + torch.testing.assert_close(output[TransitionKey.ACTION], torch.tensor([[1.5, 2.0, 3.0, *identity_rot6d]])) + + +def test_groot_n1_7_action_decode_uses_first_stat_row_for_2d_per_timestep_relative_stats(): + per_step_min = [[-float(step + 1)] * 5 for step in range(16)] + per_step_max = [[float(step + 1)] * 5 for step in range(16)] + raw_stats = { + "state": {"single_arm": {"mean": [0.0] * 5}}, + "action": {"single_arm": _symmetric_unit_stats(5)}, + "relative_action": {"single_arm": {"min": per_step_min, "max": per_step_max}}, + } + modality_config = { + "state": {"modality_keys": ["single_arm"]}, + "action": { + "delta_indices": list(range(16)), + "modality_keys": ["single_arm"], + "action_configs": [ + {"rep": "RELATIVE", "type": "NON_EEF", "format": "DEFAULT", "state_key": None} + ], + }, + } + pack_step = GrootN17PackInputsStep( + raw_stats=raw_stats, modality_config=modality_config, normalize_min_max=False, max_state_dim=8 + ) + pack_step( + {TransitionKey.OBSERVATION: {OBS_STATE: torch.zeros(1, 5)}, TransitionKey.COMPLEMENTARY_DATA: {}} + ) + decode_step = GrootN17ActionDecodeStep( + env_action_dim=5, + raw_stats=raw_stats, + modality_config=modality_config, + use_relative_action=True, + pack_step=pack_step, + ) + + output = decode_step({TransitionKey.ACTION: torch.full((1, 5), 0.5)}) + + # A popped (B, D) action decodes as chunk step 0: row 0 of the per-timestep stats + # spans [-1, 1], so 0.5 unnormalizes to 0.5 (row 1 would give 1.0) and the zero + # reference leaves it unchanged. + assert output[TransitionKey.ACTION].shape == (1, 5) + torch.testing.assert_close(output[TransitionKey.ACTION], torch.full((1, 5), 0.5)) + + +# --------------------------------------------------------------------------- +# Raw checkpoint stats fallback and per-instance raw-state cache +# --------------------------------------------------------------------------- + + +def _raw_n1_7_unknown_embodiment_config(model_path) -> GrootConfig: + input_features, output_features = _groot_features(state_dim=8, action_dim=7) + return GrootConfig( + model_version=GROOT_N1_7, + base_model_path=str(model_path), + embodiment_tag="new_embodiment", + input_features=input_features, + output_features=output_features, + device="cpu", + use_bf16=False, + ) + + +def test_raw_n1_7_checkpoint_missing_embodiment_stats_falls_back_to_dataset_stats(tmp_path): + model_path = tmp_path / "libero_spatial" + _write_raw_n1_7_libero_checkpoint(model_path) + config = _raw_n1_7_unknown_embodiment_config(model_path) + dataset_stats = { + OBS_STATE: {"min": torch.zeros(8), "max": torch.full((8,), 30.0)}, + ACTION: {"min": torch.zeros(7), "max": torch.full((7,), 30.0)}, + } + + _, postprocessor = make_groot_pre_post_processors(config, dataset_stats=dataset_stats) + + assert isinstance(postprocessor.steps[0], GrootActionUnpackUnnormalizeStep) + # The decode must invert the dataset-stats normalization applied by the pack step: + # a normalized 0 decodes to the [0, 30] midpoint 15.0 instead of passing through. + decoded = postprocessor(torch.zeros(2, 7)) + torch.testing.assert_close(decoded, torch.full((2, 7), 15.0)) + + +def test_raw_n1_7_checkpoint_missing_embodiment_stats_without_dataset_stats_raises(tmp_path): + model_path = tmp_path / "libero_spatial" + _write_raw_n1_7_libero_checkpoint(model_path) + config = _raw_n1_7_unknown_embodiment_config(model_path) + + with pytest.raises(ValueError, match="has no statistics for embodiment tag"): + make_groot_pre_post_processors(config, dataset_stats=None) + + +def test_groot_n1_7_raw_state_cache_is_per_instance(): + from lerobot.policies.groot import processor_groot + + # The process-global raw-state cache was removed: a second pipeline's preprocess + # must not leak its reference state into the first pipeline's decode step. + assert not hasattr(processor_groot, "_N1_7_RAW_STATE_CACHE") + + raw_stats = { + "state": {"single_arm": _symmetric_unit_stats(5)}, + "action": {"single_arm": _symmetric_unit_stats(5)}, + "relative_action": {"single_arm": {"min": [-1.0] * 5, "max": [1.0] * 5}}, + } + modality_config = { + "state": {"modality_keys": ["single_arm"]}, + "action": { + "modality_keys": ["single_arm"], + "action_configs": [ + {"rep": "RELATIVE", "type": "NON_EEF", "format": "DEFAULT", "state_key": None} + ], + }, + } + + def build_pair(): + pack = GrootN17PackInputsStep( + raw_stats=raw_stats, modality_config=modality_config, normalize_min_max=False, max_state_dim=8 + ) + decode = GrootN17ActionDecodeStep( + env_action_dim=5, + raw_stats=raw_stats, + modality_config=modality_config, + use_relative_action=True, + pack_step=pack, + ) + return pack, decode + + first_pack, first_decode = build_pair() + second_pack, _ = build_pair() + first_reference = torch.tensor([[0.0, 1.0, 2.0, 3.0, 4.0]]) + first_pack( + {TransitionKey.OBSERVATION: {OBS_STATE: first_reference}, TransitionKey.COMPLEMENTARY_DATA: {}} + ) + second_pack( + { + TransitionKey.OBSERVATION: {OBS_STATE: torch.full((1, 5), 99.0)}, + TransitionKey.COMPLEMENTARY_DATA: {}, + } + ) + + output = first_decode({TransitionKey.ACTION: torch.zeros(1, 5)}) + + torch.testing.assert_close(output[TransitionKey.ACTION], first_reference) + + +def test_groot_n1_7_action_decode_without_connected_pack_step_raises(): + raw_stats = { + "state": {"single_arm": _symmetric_unit_stats(5)}, + "action": {"single_arm": _symmetric_unit_stats(5)}, + "relative_action": {"single_arm": {"min": [-1.0] * 5, "max": [1.0] * 5}}, + } + modality_config = { + "state": {"modality_keys": ["single_arm"]}, + "action": { + "modality_keys": ["single_arm"], + "action_configs": [ + {"rep": "RELATIVE", "type": "NON_EEF", "format": "DEFAULT", "state_key": None} + ], + }, + } + orphan_decode = GrootN17ActionDecodeStep( + env_action_dim=5, + raw_stats=raw_stats, + modality_config=modality_config, + use_relative_action=True, + ) + + with pytest.raises(RuntimeError, match="connected GrootN17PackInputsStep"): + orphan_decode({TransitionKey.ACTION: torch.zeros(1, 5)}) + + +def test_groot_n1_7_loaded_processors_reconnect_pack_and_decode_steps(tmp_path): + model_path = tmp_path / "libero_spatial" + _write_raw_n1_7_libero_checkpoint(model_path) + config = _raw_n1_7_libero_config(model_path) + preprocessor, postprocessor = make_pre_post_processors(config, pretrained_path=str(model_path)) + save_dir = tmp_path / "saved" + config.save_pretrained(save_dir) + preprocessor.save_pretrained(save_dir) + postprocessor.save_pretrained(save_dir) + + loaded_preprocessor, loaded_postprocessor = make_pre_post_processors( + config, + pretrained_path=str(save_dir), + preprocessor_overrides={"rename_observations_processor": {"rename_map": {}}}, + ) + + # The pack/decode link is not serialized, so the factory must re-link the loaded + # decode step to the loaded pack step's per-instance raw-state cache. + pack_step = next(step for step in loaded_preprocessor.steps if isinstance(step, GrootN17PackInputsStep)) + decode_step = next( + step for step in loaded_postprocessor.steps if isinstance(step, GrootN17ActionDecodeStep) + ) + assert decode_step.pack_step is pack_step + + +# --------------------------------------------------------------------------- +# Raw checkpoint factory branch: caller overrides and hub repo loading +# --------------------------------------------------------------------------- + + +def test_raw_n1_7_checkpoint_processors_apply_caller_overrides(tmp_path): + model_path = tmp_path / "libero_spatial" + _write_raw_n1_7_libero_checkpoint(model_path) + config = _raw_n1_7_libero_config(model_path) + rename_map = {f"{OBS_IMAGES}.cam": f"{OBS_IMAGES}.image"} + + preprocessor, _ = make_pre_post_processors( + config, + pretrained_path=str(model_path), + preprocessor_overrides={"rename_observations_processor": {"rename_map": rename_map}}, + ) + + rename_step = next( + step for step in preprocessor.steps if isinstance(step, RenameObservationsProcessorStep) + ) + assert rename_step.rename_map == rename_map + + +def test_raw_n1_7_checkpoint_processors_reject_unknown_override_key(tmp_path): + model_path = tmp_path / "libero_spatial" + _write_raw_n1_7_libero_checkpoint(model_path) + config = _raw_n1_7_libero_config(model_path) + + with pytest.raises(KeyError, match="does not match any step"): + make_pre_post_processors( + config, + pretrained_path=str(model_path), + preprocessor_overrides={"missing_step": {"enabled": True}}, + ) + + +def test_raw_n1_7_checkpoint_processors_reject_unknown_override_field(tmp_path): + model_path = tmp_path / "libero_spatial" + _write_raw_n1_7_libero_checkpoint(model_path) + config = _raw_n1_7_libero_config(model_path) + + with pytest.raises(TypeError, match="is not a config field"): + make_pre_post_processors( + config, + pretrained_path=str(model_path), + preprocessor_overrides={"rename_observations_processor": {"bogus_field": 1}}, + ) + + +def test_converted_n1_7_processors_load_from_hub_repo_id_without_legacy_override_error(tmp_path, monkeypatch): + import lerobot.processor.pipeline as pipeline_module + from lerobot.policies.groot import processor_groot + + model_path = tmp_path / "libero_spatial" + _write_raw_n1_7_libero_checkpoint(model_path) + config = _raw_n1_7_libero_config(model_path) + preprocessor, postprocessor = make_pre_post_processors(config, pretrained_path=str(model_path)) + save_dir = tmp_path / "hub_repo" + config.save_pretrained(save_dir) + preprocessor.save_pretrained(save_dir) + postprocessor.save_pretrained(save_dir) + + def fake_hf_hub_download(repo_id=None, filename=None, **kwargs): + assert repo_id == "user/groot-finetune" + file_path = save_dir / filename + if not file_path.exists(): + raise FileNotFoundError(filename) + return str(file_path) + + monkeypatch.setattr(processor_groot, "hf_hub_download", fake_hf_hub_download) + monkeypatch.setattr(pipeline_module, "hf_hub_download", fake_hf_hub_download) + + # Loading from a hub repo id must inspect the serialized postprocessor and skip + # the legacy action-unpack override instead of raising + # KeyError("Override keys ['groot_action_unpack_unnormalize_v1'] ..."). + loaded_preprocessor, loaded_postprocessor = make_pre_post_processors( + config, + pretrained_path="user/groot-finetune", + ) + + assert any(isinstance(step, GrootN17PackInputsStep) for step in loaded_preprocessor.steps) + assert any(isinstance(step, GrootN17ActionDecodeStep) for step in loaded_postprocessor.steps) + + +def test_converted_n1_7_processors_retry_without_legacy_overrides_when_inspection_fails( + tmp_path, monkeypatch +): + from lerobot.policies.groot import processor_groot + + model_path = tmp_path / "libero_spatial" + _write_raw_n1_7_libero_checkpoint(model_path) + config = _raw_n1_7_libero_config(model_path) + preprocessor, postprocessor = make_pre_post_processors(config, pretrained_path=str(model_path)) + save_dir = tmp_path / "converted" + config.save_pretrained(save_dir) + preprocessor.save_pretrained(save_dir) + postprocessor.save_pretrained(save_dir) + + # Simulate the config inspection failing (e.g. offline without a cached copy): + # the injected legacy overrides then miss the serialized steps, and loading must + # fall back to the caller overrides instead of surfacing the KeyError. + monkeypatch.setattr(processor_groot, "_pretrained_processor_config_has_step", lambda *a, **k: False) + + loaded_preprocessor, loaded_postprocessor = make_pre_post_processors( + config, + pretrained_path=str(save_dir), + preprocessor_overrides={"rename_observations_processor": {"rename_map": {}}}, + ) + + assert any(isinstance(step, GrootN17PackInputsStep) for step in loaded_preprocessor.steps) + assert any(isinstance(step, GrootN17ActionDecodeStep) for step in loaded_postprocessor.steps) + + +# --------------------------------------------------------------------------- +# Camera matching warnings +# --------------------------------------------------------------------------- + + +def test_groot_n1_7_pack_inputs_warns_once_for_unmatched_and_dropped_cameras(caplog): + step = GrootN17PackInputsStep(normalize_min_max=False, video_modality_keys=["image", "side_image"]) + observation = { + f"{OBS_IMAGES}.image": torch.full((1, 3, 2, 2), 11, dtype=torch.uint8), + f"{OBS_IMAGES}.wrist": torch.full((1, 3, 2, 2), 22, dtype=torch.uint8), + OBS_STATE: torch.zeros(1, 8), + } + + with caplog.at_level(logging.WARNING): + step( + { + TransitionKey.OBSERVATION: dict(observation), + TransitionKey.COMPLEMENTARY_DATA: {"task": ["Move"]}, + } + ) + + unmatched_warnings = [ + record.getMessage() for record in caplog.records if "no matching camera" in record.getMessage() + ] + dropped_warnings = [ + record.getMessage() for record in caplog.records if "Dropping camera(s)" in record.getMessage() + ] + assert len(unmatched_warnings) == 1 + assert "side_image" in unmatched_warnings[0] + assert len(dropped_warnings) == 1 + assert f"{OBS_IMAGES}.wrist" in dropped_warnings[0] + + # The warnings are emitted once per step instance, not once per frame. + caplog.clear() + with caplog.at_level(logging.WARNING): + step( + { + TransitionKey.OBSERVATION: dict(observation), + TransitionKey.COMPLEMENTARY_DATA: {"task": ["Move"]}, + } + ) + assert not any("camera" in record.getMessage() for record in caplog.records) + + +def test_groot_n1_7_pack_inputs_does_not_warn_on_full_camera_match(caplog): + step = GrootN17PackInputsStep(normalize_min_max=False, video_modality_keys=["image", "wrist_image"]) + observation = { + f"{OBS_IMAGES}.image": torch.full((1, 3, 2, 2), 11, dtype=torch.uint8), + # LIBERO conversions expose the wrist camera as image2; the alias must match silently. + f"{OBS_IMAGES}.image2": torch.full((1, 3, 2, 2), 22, dtype=torch.uint8), + OBS_STATE: torch.zeros(1, 8), + } + + with caplog.at_level(logging.WARNING): + output = step( + {TransitionKey.OBSERVATION: observation, TransitionKey.COMPLEMENTARY_DATA: {"task": ["Move"]}} + ) + + assert output[TransitionKey.OBSERVATION]["video"].shape == (1, 1, 2, 2, 2, 3) + assert not any("camera" in record.getMessage() for record in caplog.records) + + +# --------------------------------------------------------------------------- +# GrootPolicy bf16 handling and GR00TN17 backbone loading kwargs +# --------------------------------------------------------------------------- + + +def test_groot_policy_use_bf16_false_does_not_touch_model_compute_dtype(monkeypatch): + pytest.importorskip("transformers") + from lerobot.policies.groot.groot_n1_7 import GR00TN17 + + # _DummyGrootModel deliberately defines no compute_dtype, like the real GR00TN17: + # with use_bf16=False the policy must not read or set the attribute. + monkeypatch.setattr(GR00TN17, "from_pretrained", classmethod(lambda cls, **kwargs: _DummyGrootModel())) + + policy = GrootPolicy(_groot_config(GROOT_N1_7)) + + assert not hasattr(policy._groot_model, "compute_dtype") + assert not hasattr(policy._groot_model.config, "compute_dtype") + + +def test_groot_policy_use_bf16_true_sets_model_compute_dtype(monkeypatch): + pytest.importorskip("transformers") + from lerobot.policies.groot.groot_n1_7 import GR00TN17 + + monkeypatch.setattr(GR00TN17, "from_pretrained", classmethod(lambda cls, **kwargs: _DummyGrootModel())) + input_features, output_features = _groot_features(state_dim=8, action_dim=7) + config = GrootConfig( + model_version=GROOT_N1_7, + input_features=input_features, + output_features=output_features, + device="cpu", + use_bf16=True, + ) + + policy = GrootPolicy(config) + + assert policy._groot_model.compute_dtype == "bfloat16" + assert policy._groot_model.config.compute_dtype == "bfloat16" + + +def _stub_gr00t_n1_7_loading(monkeypatch, called, snapshot_kwargs): + from huggingface_hub.errors import HFValidationError + + import lerobot.policies.groot.groot_n1_7 as groot_n1_7 + + class FakeLoadedModel: + def __init__(self): + self.config = SimpleNamespace(tune_top_llm_layers=0) + self.backbone = SimpleNamespace(set_trainable_parameters=lambda **kwargs: None) + self.action_head = SimpleNamespace(set_trainable_parameters=lambda **kwargs: None) + + def fake_snapshot_download(*args, **kwargs): + snapshot_kwargs.update(kwargs) + raise HFValidationError("local path") + + def fake_super_from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + called["pretrained_model_name_or_path"] = pretrained_model_name_or_path + called.update(kwargs) + return FakeLoadedModel() + + monkeypatch.setattr(groot_n1_7, "snapshot_download", fake_snapshot_download) + monkeypatch.setattr( + groot_n1_7.PreTrainedModel, + "from_pretrained", + classmethod(fake_super_from_pretrained), + ) + return groot_n1_7 + + +def test_gr00t_n1_7_from_pretrained_does_not_forward_revision_to_backbone(monkeypatch, tmp_path): + pytest.importorskip("transformers") + called: dict = {} + snapshot_kwargs: dict = {} + groot_n1_7 = _stub_gr00t_n1_7_loading(monkeypatch, called, snapshot_kwargs) + + groot_n1_7.GR00TN17.from_pretrained( + str(tmp_path), + revision="deadbeefcafe", + cache_dir=str(tmp_path / "cache"), + token="hf_dummy", + local_files_only=False, + ) + + # 'revision' pins the GR00T checkpoint repo and must not leak into the unrelated + # backbone repo load; the repo-agnostic hub kwargs are still forwarded. + transformers_loading_kwargs = called["transformers_loading_kwargs"] + assert "revision" not in transformers_loading_kwargs + assert transformers_loading_kwargs["cache_dir"] == str(tmp_path / "cache") + assert transformers_loading_kwargs["local_files_only"] is False + assert transformers_loading_kwargs["token"] == "hf_dummy" + assert snapshot_kwargs["revision"] == "deadbeefcafe" + + +def test_gr00t_n1_7_from_pretrained_preserves_explicit_backbone_revision(monkeypatch, tmp_path): + pytest.importorskip("transformers") + called: dict = {} + snapshot_kwargs: dict = {} + groot_n1_7 = _stub_gr00t_n1_7_loading(monkeypatch, called, snapshot_kwargs) + + groot_n1_7.GR00TN17.from_pretrained( + str(tmp_path), + revision="deadbeefcafe", + transformers_loading_kwargs={"revision": "backbone-tag", "trust_remote_code": True}, + ) + + assert called["transformers_loading_kwargs"]["revision"] == "backbone-tag" + + +def test_get_backbone_cls_warns_only_for_unrecognized_qwen_backbone_names(caplog): + pytest.importorskip("transformers") + import lerobot.policies.groot.groot_n1_7 as groot_n1_7 + + with caplog.at_level(logging.WARNING, logger=groot_n1_7.__name__): + recognized = groot_n1_7.get_backbone_cls( + groot_n1_7.GR00TN17Config(model_name="nvidia/Cosmos-Reason2-2B") + ) + assert recognized is groot_n1_7.Qwen3Backbone + assert not any("Unrecognized" in record.getMessage() for record in caplog.records) + + # Local snapshot paths carry no recognized repo marker; with the default + # backbone_model_type='qwen' they must load with a loud assumption warning. + with caplog.at_level(logging.WARNING, logger=groot_n1_7.__name__): + local = groot_n1_7.get_backbone_cls(groot_n1_7.GR00TN17Config(model_name="/local/backbone/snapshot")) + assert local is groot_n1_7.Qwen3Backbone + warnings = [ + record.getMessage() + for record in caplog.records + if "Unrecognized GR00T N1.7 backbone model name" in record.getMessage() + ] + assert len(warnings) == 1 + + with pytest.raises(ValueError, match="Unsupported GR00T N1.7 backbone model"): + groot_n1_7.get_backbone_cls( + groot_n1_7.GR00TN17Config(model_name="totally/bogus", backbone_model_type=None) + )