mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 00:37:10 +00:00
Merge pull request #12 from acwrenn53/exp/groot-n17-test-groot-lerobot
Adopt test_groot_lerobot for GR00T N1.7, drop N1.5
This commit is contained in:
@@ -14,7 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test script for LeRobot's Groot policy forward and inference passes."""
|
||||
"""Test script for LeRobot's GR00T N1.7 policy forward and inference passes."""
|
||||
|
||||
import gc
|
||||
import os
|
||||
@@ -41,13 +41,20 @@ pytestmark = pytest.mark.skipif(
|
||||
)
|
||||
|
||||
|
||||
# Define constants for dummy data
|
||||
# Define constants for dummy data (GR00T N1.7 native conventions).
|
||||
# N1.7 internally uses a 40-step action chunk, 132-dim state/action, and 256px images
|
||||
# (see GrootConfig.__post_init__). Use a chunk-sized action horizon so the dummy batch
|
||||
# matches the model's native action space.
|
||||
DUMMY_STATE_DIM = 44
|
||||
DUMMY_ACTION_DIM = 44
|
||||
DUMMY_ACTION_HORIZON = 16
|
||||
DUMMY_ACTION_HORIZON = 40
|
||||
IMAGE_SIZE = 256
|
||||
DEVICE = auto_select_torch_device()
|
||||
MODEL_PATH = "aractingi/bimanual-handover-groot-10k"
|
||||
# GR00T N1.7 checkpoint (N1.5 is no longer supported). The N1.7-3B base model loads
|
||||
# via GrootPolicy.from_pretrained with root-level sharded safetensors.
|
||||
MODEL_PATH = "nvidia/GR00T-N1.7-3B"
|
||||
# Valid N1.7 embodiment tag carried by the checkpoint metadata.
|
||||
EMBODIMENT_TAG = "gr1_unified"
|
||||
|
||||
|
||||
def cleanup_memory():
|
||||
@@ -88,13 +95,13 @@ def instantiate_lerobot_groot(
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""Instantiate LeRobot Groot policy with preprocessor and postprocessor."""
|
||||
"""Instantiate LeRobot GR00T N1.7 policy with preprocessor and postprocessor."""
|
||||
if from_pretrained:
|
||||
policy = GrootPolicy.from_pretrained(
|
||||
pretrained_name_or_path=model_path,
|
||||
strict=False,
|
||||
)
|
||||
policy.config.embodiment_tag = "gr1"
|
||||
policy.config.embodiment_tag = EMBODIMENT_TAG
|
||||
else:
|
||||
config = GrootConfig(
|
||||
base_model_path=model_path,
|
||||
@@ -102,7 +109,7 @@ def instantiate_lerobot_groot(
|
||||
chunk_size=DUMMY_ACTION_HORIZON,
|
||||
image_size=[IMAGE_SIZE, IMAGE_SIZE],
|
||||
device=DEVICE,
|
||||
embodiment_tag="gr1",
|
||||
embodiment_tag=EMBODIMENT_TAG,
|
||||
)
|
||||
policy = GrootPolicy(config)
|
||||
|
||||
@@ -148,8 +155,8 @@ def create_dummy_data(device=DEVICE):
|
||||
|
||||
@require_cuda
|
||||
def test_lerobot_groot_inference():
|
||||
"""Test the inference pass (select_action) of LeRobot's Groot policy."""
|
||||
print("Test: LeRobot Groot Inference Pass")
|
||||
"""Test the inference pass (select_action) of LeRobot's GR00T N1.7 policy."""
|
||||
print("Test: LeRobot GR00T N1.7 Inference Pass")
|
||||
|
||||
set_seed_all(42)
|
||||
|
||||
@@ -181,9 +188,9 @@ def test_lerobot_groot_inference():
|
||||
|
||||
@require_cuda
|
||||
def test_lerobot_groot_forward_pass():
|
||||
"""Test the forward pass of LeRobot's Groot policy."""
|
||||
"""Test the forward pass of LeRobot's GR00T N1.7 policy."""
|
||||
print("\n" + "=" * 50)
|
||||
print("Test: LeRobot Groot Forward Pass (Training Mode)")
|
||||
print("Test: LeRobot GR00T N1.7 Forward Pass (Training Mode)")
|
||||
|
||||
set_seed_all(42)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user