diff --git a/tests/policies/groot/test_groot_lerobot.py b/tests/policies/groot/test_groot_lerobot.py index 788935d4f..34acdef2f 100644 --- a/tests/policies/groot/test_groot_lerobot.py +++ b/tests/policies/groot/test_groot_lerobot.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Test script for LeRobot's Groot policy forward and inference passes.""" +"""Test script for LeRobot's GR00T N1.7 policy forward and inference passes.""" import gc import os @@ -41,13 +41,20 @@ pytestmark = pytest.mark.skipif( ) -# Define constants for dummy data +# Define constants for dummy data (GR00T N1.7 native conventions). +# N1.7 internally uses a 40-step action chunk, 132-dim state/action, and 256px images +# (see GrootConfig.__post_init__). Use a chunk-sized action horizon so the dummy batch +# matches the model's native action space. DUMMY_STATE_DIM = 44 DUMMY_ACTION_DIM = 44 -DUMMY_ACTION_HORIZON = 16 +DUMMY_ACTION_HORIZON = 40 IMAGE_SIZE = 256 DEVICE = auto_select_torch_device() -MODEL_PATH = "aractingi/bimanual-handover-groot-10k" +# GR00T N1.7 checkpoint (N1.5 is no longer supported). The N1.7-3B base model loads +# via GrootPolicy.from_pretrained with root-level sharded safetensors. +MODEL_PATH = "nvidia/GR00T-N1.7-3B" +# Valid N1.7 embodiment tag carried by the checkpoint metadata. +EMBODIMENT_TAG = "gr1_unified" def cleanup_memory(): @@ -88,13 +95,13 @@ def instantiate_lerobot_groot( PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], PolicyProcessorPipeline[PolicyAction, PolicyAction], ]: - """Instantiate LeRobot Groot policy with preprocessor and postprocessor.""" + """Instantiate LeRobot GR00T N1.7 policy with preprocessor and postprocessor.""" if from_pretrained: policy = GrootPolicy.from_pretrained( pretrained_name_or_path=model_path, strict=False, ) - policy.config.embodiment_tag = "gr1" + policy.config.embodiment_tag = EMBODIMENT_TAG else: config = GrootConfig( base_model_path=model_path, @@ -102,7 +109,7 @@ def instantiate_lerobot_groot( chunk_size=DUMMY_ACTION_HORIZON, image_size=[IMAGE_SIZE, IMAGE_SIZE], device=DEVICE, - embodiment_tag="gr1", + embodiment_tag=EMBODIMENT_TAG, ) policy = GrootPolicy(config) @@ -148,8 +155,8 @@ def create_dummy_data(device=DEVICE): @require_cuda def test_lerobot_groot_inference(): - """Test the inference pass (select_action) of LeRobot's Groot policy.""" - print("Test: LeRobot Groot Inference Pass") + """Test the inference pass (select_action) of LeRobot's GR00T N1.7 policy.""" + print("Test: LeRobot GR00T N1.7 Inference Pass") set_seed_all(42) @@ -181,9 +188,9 @@ def test_lerobot_groot_inference(): @require_cuda def test_lerobot_groot_forward_pass(): - """Test the forward pass of LeRobot's Groot policy.""" + """Test the forward pass of LeRobot's GR00T N1.7 policy.""" print("\n" + "=" * 50) - print("Test: LeRobot Groot Forward Pass (Training Mode)") + print("Test: LeRobot GR00T N1.7 Forward Pass (Training Mode)") set_seed_all(42)