mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-05 09:07:03 +00:00
fix(ci): guard dependecy checks
This commit is contained in:
@@ -103,6 +103,8 @@ def _raw_n1_7_libero_config(model_path) -> GrootConfig:
|
||||
|
||||
|
||||
def test_n1_7_backbone_accepts_transformers_5_layout_and_forwards_mm_token_type_ids(monkeypatch):
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
|
||||
import lerobot.policies.groot.groot_n1_7 as groot_n1_7
|
||||
@@ -196,6 +198,8 @@ def test_n1_7_backbone_accepts_transformers_5_layout_and_forwards_mm_token_type_
|
||||
|
||||
|
||||
def test_n1_7_backbone_preserves_missing_qwen_optional_dependency_error(monkeypatch):
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
import lerobot.policies.groot.groot_n1_7 as groot_n1_7
|
||||
|
||||
monkeypatch.setattr(
|
||||
@@ -416,6 +420,8 @@ def test_groot_predict_action_chunk_accepts_rtc_kwargs():
|
||||
|
||||
|
||||
def test_groot_predict_action_chunk_forwards_n1_7_rtc_prefix(monkeypatch):
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.policies.groot.groot_n1_7 import GR00TN17
|
||||
|
||||
dummy_model = _DummyGrootModel()
|
||||
@@ -445,6 +451,8 @@ def test_groot_predict_action_chunk_forwards_n1_7_rtc_prefix(monkeypatch):
|
||||
|
||||
|
||||
def test_groot_predict_action_chunk_strips_padded_n1_7_rtc_prefix(monkeypatch):
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.policies.groot.groot_n1_7 import GR00TN17
|
||||
|
||||
dummy_model = _DummyGrootModel()
|
||||
@@ -478,6 +486,8 @@ def test_groot_predict_action_chunk_strips_padded_n1_7_rtc_prefix(monkeypatch):
|
||||
|
||||
|
||||
def test_groot_n1_7_predict_action_chunk_truncates_to_checkpoint_valid_horizon(tmp_path, monkeypatch):
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.policies.groot.groot_n1_7 import GR00TN17
|
||||
|
||||
model_path = tmp_path / "libero_spatial"
|
||||
@@ -531,6 +541,8 @@ def test_groot_from_pretrained_rejects_mismatched_caller_config(tmp_path):
|
||||
|
||||
|
||||
def test_groot_from_pretrained_keeps_matching_caller_config(tmp_path, monkeypatch):
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.policies.groot.groot_n1_7 import GR00TN17
|
||||
|
||||
model_path = tmp_path / "GR00T-N1.7-local"
|
||||
@@ -545,6 +557,8 @@ def test_groot_from_pretrained_keeps_matching_caller_config(tmp_path, monkeypatc
|
||||
|
||||
|
||||
def test_groot_from_pretrained_infers_n1_7_from_ambiguous_local_config(tmp_path, monkeypatch):
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.policies.groot.groot_n1_7 import GR00TN17
|
||||
|
||||
model_path = tmp_path / "local-checkpoint"
|
||||
@@ -2518,6 +2532,8 @@ def test_groot_n1_7_relative_action_stats_skip_padded_tail_chunks():
|
||||
|
||||
|
||||
def test_groot_policy_selects_n1_7_model_class(monkeypatch):
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.policies.groot.groot_n1_7 import GR00TN17
|
||||
|
||||
called = {}
|
||||
@@ -2535,6 +2551,8 @@ def test_groot_policy_selects_n1_7_model_class(monkeypatch):
|
||||
|
||||
|
||||
def test_groot_policy_forwards_n1_7_qwen_inputs(monkeypatch):
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.policies.groot.groot_n1_7 import GR00TN17
|
||||
|
||||
dummy_model = _DummyGrootModel()
|
||||
@@ -2593,6 +2611,8 @@ def test_groot_select_action_rejects_relative_action_policies():
|
||||
|
||||
|
||||
def test_groot_n1_7_select_action_uses_checkpoint_valid_horizon(tmp_path, monkeypatch):
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.policies.groot.groot_n1_7 import GR00TN17
|
||||
|
||||
model_path = tmp_path / "libero_spatial"
|
||||
@@ -2785,6 +2805,8 @@ def test_qwen3_backbone_can_initialize_from_config_without_downloading_weights(m
|
||||
|
||||
|
||||
def test_gr00t_n1_7_from_pretrained_defers_backbone_weight_loading(monkeypatch, tmp_path):
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from huggingface_hub.errors import HFValidationError
|
||||
|
||||
import lerobot.policies.groot.groot_n1_7 as groot_n1_7
|
||||
|
||||
@@ -21,7 +21,6 @@ from pathlib import Path
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
|
||||
from lerobot.policies.groot.action_head.cross_attention_dit import AlternateVLDiT
|
||||
from lerobot.policies.groot.groot_n1_7 import GR00TN17
|
||||
@@ -110,6 +109,8 @@ def test_groot_n1_7_vlm_chat_content_order_matches_oss_reference():
|
||||
def test_groot_n1_7_alternate_vl_dit_matches_oss_reference():
|
||||
"""Run the LeRobot DiT with native OSS weights and identical inputs."""
|
||||
|
||||
pytest.importorskip("diffusers")
|
||||
|
||||
fixture = torch.load(_fixture_path("alternate_vl_dit_small.pt"), map_location="cpu", weights_only=True)
|
||||
model = AlternateVLDiT(
|
||||
output_dim=8,
|
||||
@@ -228,6 +229,10 @@ def test_groot_n1_7_qwen_backbone_matches_oss_checkpoint_reference():
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("The 3B OSS Qwen parity test requires CUDA.")
|
||||
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
|
||||
fixture = torch.load(_fixture_path("qwen_backbone_so101.pt"), map_location="cpu", weights_only=True)
|
||||
model = GR00TN17.from_pretrained(checkpoint).to(device="cuda", dtype=torch.bfloat16).eval()
|
||||
backbone_input = BatchFeature(
|
||||
|
||||
Reference in New Issue
Block a user