From 68f869b7a00d5d1b1ff9583d93224cd82fd07043 Mon Sep 17 00:00:00 2001 From: nv-sachdevkartik Date: Fri, 12 Jun 2026 08:36:49 +0000 Subject: [PATCH] test(groot): adopt test_groot_lerobot for GR00T N1.7, drop N1.5 The test loaded MODEL_PATH='aractingi/bimanual-handover-groot-10k', an N1.5 checkpoint (config base_model_path=nvidia/GR00T-N1.5-3B, no model_version). On load, model_version defaults to n1.7 while the base path infers n1.5, so the version-consistency guard in GrootConfig.__post_init__ raised ValueError and both test_lerobot_groot_inference and test_lerobot_groot_forward_pass failed. N1.5 is no longer a supported model_version. Adopt the test for N1.7: - MODEL_PATH -> nvidia/GR00T-N1.7-3B (root-level sharded safetensors; loads via GrootPolicy.from_pretrained as a base N1.7 model). - Embodiment tag 'gr1' (N1.5) -> 'gr1_unified' (valid N1.7 tag from the checkpoint embodiment_id.json), via a single EMBODIMENT_TAG constant. - DUMMY_ACTION_HORIZON 16 -> 40 to match N1.7's native action-chunk size. - Docstrings/labels updated to 'GR00T N1.7'. Both tests run and pass on CUDA; full tests/policies/groot/ suite is 73 passed / 0 failed / 0 skipped. --- tests/policies/groot/test_groot_lerobot.py | 29 ++++++++++++++-------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/tests/policies/groot/test_groot_lerobot.py b/tests/policies/groot/test_groot_lerobot.py index 788935d4f..34acdef2f 100644 --- a/tests/policies/groot/test_groot_lerobot.py +++ b/tests/policies/groot/test_groot_lerobot.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Test script for LeRobot's Groot policy forward and inference passes.""" +"""Test script for LeRobot's GR00T N1.7 policy forward and inference passes.""" import gc import os @@ -41,13 +41,20 @@ pytestmark = pytest.mark.skipif( ) -# Define constants for dummy data +# Define constants for dummy data (GR00T N1.7 native conventions). +# N1.7 internally uses a 40-step action chunk, 132-dim state/action, and 256px images +# (see GrootConfig.__post_init__). Use a chunk-sized action horizon so the dummy batch +# matches the model's native action space. DUMMY_STATE_DIM = 44 DUMMY_ACTION_DIM = 44 -DUMMY_ACTION_HORIZON = 16 +DUMMY_ACTION_HORIZON = 40 IMAGE_SIZE = 256 DEVICE = auto_select_torch_device() -MODEL_PATH = "aractingi/bimanual-handover-groot-10k" +# GR00T N1.7 checkpoint (N1.5 is no longer supported). The N1.7-3B base model loads +# via GrootPolicy.from_pretrained with root-level sharded safetensors. +MODEL_PATH = "nvidia/GR00T-N1.7-3B" +# Valid N1.7 embodiment tag carried by the checkpoint metadata. +EMBODIMENT_TAG = "gr1_unified" def cleanup_memory(): @@ -88,13 +95,13 @@ def instantiate_lerobot_groot( PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], PolicyProcessorPipeline[PolicyAction, PolicyAction], ]: - """Instantiate LeRobot Groot policy with preprocessor and postprocessor.""" + """Instantiate LeRobot GR00T N1.7 policy with preprocessor and postprocessor.""" if from_pretrained: policy = GrootPolicy.from_pretrained( pretrained_name_or_path=model_path, strict=False, ) - policy.config.embodiment_tag = "gr1" + policy.config.embodiment_tag = EMBODIMENT_TAG else: config = GrootConfig( base_model_path=model_path, @@ -102,7 +109,7 @@ def instantiate_lerobot_groot( chunk_size=DUMMY_ACTION_HORIZON, image_size=[IMAGE_SIZE, IMAGE_SIZE], device=DEVICE, - embodiment_tag="gr1", + embodiment_tag=EMBODIMENT_TAG, ) policy = GrootPolicy(config) @@ -148,8 +155,8 @@ def create_dummy_data(device=DEVICE): @require_cuda def test_lerobot_groot_inference(): - """Test the inference pass (select_action) of LeRobot's Groot policy.""" - print("Test: LeRobot Groot Inference Pass") + """Test the inference pass (select_action) of LeRobot's GR00T N1.7 policy.""" + print("Test: LeRobot GR00T N1.7 Inference Pass") set_seed_all(42) @@ -181,9 +188,9 @@ def test_lerobot_groot_inference(): @require_cuda def test_lerobot_groot_forward_pass(): - """Test the forward pass of LeRobot's Groot policy.""" + """Test the forward pass of LeRobot's GR00T N1.7 policy.""" print("\n" + "=" * 50) - print("Test: LeRobot Groot Forward Pass (Training Mode)") + print("Test: LeRobot GR00T N1.7 Forward Pass (Training Mode)") set_seed_all(42)