adding guards to avoid needing transformers and diffusers for type checking and basic tests

This commit is contained in:
Maximellerbach
2026-05-11 19:05:26 +02:00
parent ff7dc0519e
commit 96a8efddf0
4 changed files with 40 additions and 6 deletions
+19 -4
View File
@@ -1,16 +1,31 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING
import torch import torch
import torch.nn.functional as F # noqa: N812 import torch.nn.functional as F # noqa: N812
from diffusers import ConfigMixin, ModelMixin
from diffusers.configuration_utils import register_to_config
from diffusers.models.attention import Attention, FeedForward
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from torch import nn from torch import nn
from torch.distributions import Beta from torch.distributions import Beta
from lerobot.utils.import_utils import _diffusers_available
if TYPE_CHECKING or _diffusers_available:
from diffusers import ConfigMixin, ModelMixin
from diffusers.configuration_utils import register_to_config
from diffusers.models.attention import Attention, FeedForward
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
else:
class ModelMixin: # type: ignore[no-redef]
pass
class ConfigMixin: # type: ignore[no-redef]
pass
register_to_config = lambda f: f # noqa: E731
Attention = FeedForward = TimestepEmbedding = Timesteps = None
from .configuration_vla_jepa import VLAJEPAConfig from .configuration_vla_jepa import VLAJEPAConfig
@@ -2,17 +2,24 @@ from __future__ import annotations
from collections import deque from collections import deque
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F # noqa: N812 import torch.nn.functional as F # noqa: N812
from PIL import Image from PIL import Image
from torch import Tensor, nn from torch import Tensor, nn
from transformers import AutoModel, AutoVideoProcessor
from lerobot.policies.pretrained import PreTrainedPolicy, T from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.utils import populate_queues from lerobot.policies.utils import populate_queues
from lerobot.utils.constants import ACTION, OBS_STATE from lerobot.utils.constants import ACTION, OBS_STATE
from lerobot.utils.import_utils import _transformers_available, require_package
if TYPE_CHECKING or _transformers_available:
from transformers import AutoModel, AutoVideoProcessor
else:
AutoModel = None
AutoVideoProcessor = None
from .action_head import VLAJEPAActionHead from .action_head import VLAJEPAActionHead
from .configuration_vla_jepa import VLAJEPAConfig from .configuration_vla_jepa import VLAJEPAConfig
@@ -43,6 +50,7 @@ class VLAJEPAModel(nn.Module):
def __init__(self, config: VLAJEPAConfig) -> None: def __init__(self, config: VLAJEPAConfig) -> None:
super().__init__() super().__init__()
require_package("transformers", extra="vla_jepa")
self.config = config self.config = config
# Vision-language backbone # Vision-language backbone
@@ -1,11 +1,19 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Sequence from collections.abc import Sequence
from typing import TYPE_CHECKING
import numpy as np import numpy as np
import torch import torch
from PIL import Image from PIL import Image
from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
from lerobot.utils.import_utils import _transformers_available
if TYPE_CHECKING or _transformers_available:
from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
else:
AutoProcessor = None
Qwen3VLForConditionalGeneration = None
from .configuration_vla_jepa import VLAJEPAConfig from .configuration_vla_jepa import VLAJEPAConfig
+3
View File
@@ -17,6 +17,9 @@ from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
from lerobot.policies.vla_jepa.modeling_vla_jepa import VLAJEPAPolicy from lerobot.policies.vla_jepa.modeling_vla_jepa import VLAJEPAPolicy
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
pytest.importorskip("transformers")
pytest.importorskip("diffusers")
pytestmark = pytest.mark.filterwarnings( pytestmark = pytest.mark.filterwarnings(
"ignore:In CPU autocast, but the target dtype is not supported:UserWarning" "ignore:In CPU autocast, but the target dtype is not supported:UserWarning"
) )