From 96a8efddf06db89cc8d04197fa2d3997521628f6 Mon Sep 17 00:00:00 2001 From: Maximellerbach Date: Mon, 11 May 2026 19:05:26 +0200 Subject: [PATCH] adding guards to avoid needing transformers and diffusers for type checking and basic tests --- src/lerobot/policies/vla_jepa/action_head.py | 23 +++++++++++++++---- .../policies/vla_jepa/modeling_vla_jepa.py | 10 +++++++- .../policies/vla_jepa/qwen_interface.py | 10 +++++++- tests/policies/vla_jepa/test_vla_jepa.py | 3 +++ 4 files changed, 40 insertions(+), 6 deletions(-) diff --git a/src/lerobot/policies/vla_jepa/action_head.py b/src/lerobot/policies/vla_jepa/action_head.py index 76d105194..ee9d2d7da 100644 --- a/src/lerobot/policies/vla_jepa/action_head.py +++ b/src/lerobot/policies/vla_jepa/action_head.py @@ -1,16 +1,31 @@ from __future__ import annotations from dataclasses import dataclass +from typing import TYPE_CHECKING import torch import torch.nn.functional as F # noqa: N812 -from diffusers import ConfigMixin, ModelMixin -from diffusers.configuration_utils import register_to_config -from diffusers.models.attention import Attention, FeedForward -from diffusers.models.embeddings import TimestepEmbedding, Timesteps from torch import nn from torch.distributions import Beta +from lerobot.utils.import_utils import _diffusers_available + +if TYPE_CHECKING or _diffusers_available: + from diffusers import ConfigMixin, ModelMixin + from diffusers.configuration_utils import register_to_config + from diffusers.models.attention import Attention, FeedForward + from diffusers.models.embeddings import TimestepEmbedding, Timesteps +else: + + class ModelMixin: # type: ignore[no-redef] + pass + + class ConfigMixin: # type: ignore[no-redef] + pass + + register_to_config = lambda f: f # noqa: E731 + Attention = FeedForward = TimestepEmbedding = Timesteps = None + from .configuration_vla_jepa import VLAJEPAConfig diff --git a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py index aa178be51..21b55987c 100644 --- a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py +++ b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py @@ -2,17 +2,24 @@ from __future__ import annotations from collections import deque from pathlib import Path +from typing import TYPE_CHECKING import numpy as np import torch import torch.nn.functional as F # noqa: N812 from PIL import Image from torch import Tensor, nn -from transformers import AutoModel, AutoVideoProcessor from lerobot.policies.pretrained import PreTrainedPolicy, T from lerobot.policies.utils import populate_queues from lerobot.utils.constants import ACTION, OBS_STATE +from lerobot.utils.import_utils import _transformers_available, require_package + +if TYPE_CHECKING or _transformers_available: + from transformers import AutoModel, AutoVideoProcessor +else: + AutoModel = None + AutoVideoProcessor = None from .action_head import VLAJEPAActionHead from .configuration_vla_jepa import VLAJEPAConfig @@ -43,6 +50,7 @@ class VLAJEPAModel(nn.Module): def __init__(self, config: VLAJEPAConfig) -> None: super().__init__() + require_package("transformers", extra="vla_jepa") self.config = config # Vision-language backbone diff --git a/src/lerobot/policies/vla_jepa/qwen_interface.py b/src/lerobot/policies/vla_jepa/qwen_interface.py index 044b6f989..592ecad93 100644 --- a/src/lerobot/policies/vla_jepa/qwen_interface.py +++ b/src/lerobot/policies/vla_jepa/qwen_interface.py @@ -1,11 +1,19 @@ from __future__ import annotations from collections.abc import Sequence +from typing import TYPE_CHECKING import numpy as np import torch from PIL import Image -from transformers import AutoProcessor, Qwen3VLForConditionalGeneration + +from lerobot.utils.import_utils import _transformers_available + +if TYPE_CHECKING or _transformers_available: + from transformers import AutoProcessor, Qwen3VLForConditionalGeneration +else: + AutoProcessor = None + Qwen3VLForConditionalGeneration = None from .configuration_vla_jepa import VLAJEPAConfig diff --git a/tests/policies/vla_jepa/test_vla_jepa.py b/tests/policies/vla_jepa/test_vla_jepa.py index 48c8ab9b4..ffec4c201 100644 --- a/tests/policies/vla_jepa/test_vla_jepa.py +++ b/tests/policies/vla_jepa/test_vla_jepa.py @@ -17,6 +17,9 @@ from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig from lerobot.policies.vla_jepa.modeling_vla_jepa import VLAJEPAPolicy from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE +pytest.importorskip("transformers") +pytest.importorskip("diffusers") + pytestmark = pytest.mark.filterwarnings( "ignore:In CPU autocast, but the target dtype is not supported:UserWarning" )