Merge pull request #12 from acwrenn53/exp/groot-n17-test-groot-lerobot

Adopt test_groot_lerobot for GR00T N1.7, drop N1.5
This commit is contained in:
Kartik
2026-06-12 11:01:25 +02:00
committed by GitHub
+18 -11
View File
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test script for LeRobot's Groot policy forward and inference passes."""
"""Test script for LeRobot's GR00T N1.7 policy forward and inference passes."""
import gc
import os
@@ -41,13 +41,20 @@ pytestmark = pytest.mark.skipif(
)
# Define constants for dummy data
# Define constants for dummy data (GR00T N1.7 native conventions).
# N1.7 internally uses a 40-step action chunk, 132-dim state/action, and 256px images
# (see GrootConfig.__post_init__). Use a chunk-sized action horizon so the dummy batch
# matches the model's native action space.
DUMMY_STATE_DIM = 44
DUMMY_ACTION_DIM = 44
DUMMY_ACTION_HORIZON = 16
DUMMY_ACTION_HORIZON = 40
IMAGE_SIZE = 256
DEVICE = auto_select_torch_device()
MODEL_PATH = "aractingi/bimanual-handover-groot-10k"
# GR00T N1.7 checkpoint (N1.5 is no longer supported). The N1.7-3B base model loads
# via GrootPolicy.from_pretrained with root-level sharded safetensors.
MODEL_PATH = "nvidia/GR00T-N1.7-3B"
# Valid N1.7 embodiment tag carried by the checkpoint metadata.
EMBODIMENT_TAG = "gr1_unified"
def cleanup_memory():
@@ -88,13 +95,13 @@ def instantiate_lerobot_groot(
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""Instantiate LeRobot Groot policy with preprocessor and postprocessor."""
"""Instantiate LeRobot GR00T N1.7 policy with preprocessor and postprocessor."""
if from_pretrained:
policy = GrootPolicy.from_pretrained(
pretrained_name_or_path=model_path,
strict=False,
)
policy.config.embodiment_tag = "gr1"
policy.config.embodiment_tag = EMBODIMENT_TAG
else:
config = GrootConfig(
base_model_path=model_path,
@@ -102,7 +109,7 @@ def instantiate_lerobot_groot(
chunk_size=DUMMY_ACTION_HORIZON,
image_size=[IMAGE_SIZE, IMAGE_SIZE],
device=DEVICE,
embodiment_tag="gr1",
embodiment_tag=EMBODIMENT_TAG,
)
policy = GrootPolicy(config)
@@ -148,8 +155,8 @@ def create_dummy_data(device=DEVICE):
@require_cuda
def test_lerobot_groot_inference():
"""Test the inference pass (select_action) of LeRobot's Groot policy."""
print("Test: LeRobot Groot Inference Pass")
"""Test the inference pass (select_action) of LeRobot's GR00T N1.7 policy."""
print("Test: LeRobot GR00T N1.7 Inference Pass")
set_seed_all(42)
@@ -181,9 +188,9 @@ def test_lerobot_groot_inference():
@require_cuda
def test_lerobot_groot_forward_pass():
"""Test the forward pass of LeRobot's Groot policy."""
"""Test the forward pass of LeRobot's GR00T N1.7 policy."""
print("\n" + "=" * 50)
print("Test: LeRobot Groot Forward Pass (Training Mode)")
print("Test: LeRobot GR00T N1.7 Forward Pass (Training Mode)")
set_seed_all(42)