mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 19:19:56 +00:00
adding guards to avoid needing transformers and diffusers for type checking and basic tests
This commit is contained in:
@@ -1,16 +1,31 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F # noqa: N812
|
import torch.nn.functional as F # noqa: N812
|
||||||
from diffusers import ConfigMixin, ModelMixin
|
|
||||||
from diffusers.configuration_utils import register_to_config
|
|
||||||
from diffusers.models.attention import Attention, FeedForward
|
|
||||||
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.distributions import Beta
|
from torch.distributions import Beta
|
||||||
|
|
||||||
|
from lerobot.utils.import_utils import _diffusers_available
|
||||||
|
|
||||||
|
if TYPE_CHECKING or _diffusers_available:
|
||||||
|
from diffusers import ConfigMixin, ModelMixin
|
||||||
|
from diffusers.configuration_utils import register_to_config
|
||||||
|
from diffusers.models.attention import Attention, FeedForward
|
||||||
|
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
||||||
|
else:
|
||||||
|
|
||||||
|
class ModelMixin: # type: ignore[no-redef]
|
||||||
|
pass
|
||||||
|
|
||||||
|
class ConfigMixin: # type: ignore[no-redef]
|
||||||
|
pass
|
||||||
|
|
||||||
|
register_to_config = lambda f: f # noqa: E731
|
||||||
|
Attention = FeedForward = TimestepEmbedding = Timesteps = None
|
||||||
|
|
||||||
from .configuration_vla_jepa import VLAJEPAConfig
|
from .configuration_vla_jepa import VLAJEPAConfig
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,17 +2,24 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F # noqa: N812
|
import torch.nn.functional as F # noqa: N812
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from transformers import AutoModel, AutoVideoProcessor
|
|
||||||
|
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||||
from lerobot.policies.utils import populate_queues
|
from lerobot.policies.utils import populate_queues
|
||||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||||
|
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||||
|
|
||||||
|
if TYPE_CHECKING or _transformers_available:
|
||||||
|
from transformers import AutoModel, AutoVideoProcessor
|
||||||
|
else:
|
||||||
|
AutoModel = None
|
||||||
|
AutoVideoProcessor = None
|
||||||
|
|
||||||
from .action_head import VLAJEPAActionHead
|
from .action_head import VLAJEPAActionHead
|
||||||
from .configuration_vla_jepa import VLAJEPAConfig
|
from .configuration_vla_jepa import VLAJEPAConfig
|
||||||
@@ -43,6 +50,7 @@ class VLAJEPAModel(nn.Module):
|
|||||||
|
|
||||||
def __init__(self, config: VLAJEPAConfig) -> None:
|
def __init__(self, config: VLAJEPAConfig) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
require_package("transformers", extra="vla_jepa")
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
# Vision-language backbone
|
# Vision-language backbone
|
||||||
|
|||||||
@@ -1,11 +1,19 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
|
|
||||||
|
from lerobot.utils.import_utils import _transformers_available
|
||||||
|
|
||||||
|
if TYPE_CHECKING or _transformers_available:
|
||||||
|
from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
|
||||||
|
else:
|
||||||
|
AutoProcessor = None
|
||||||
|
Qwen3VLForConditionalGeneration = None
|
||||||
|
|
||||||
from .configuration_vla_jepa import VLAJEPAConfig
|
from .configuration_vla_jepa import VLAJEPAConfig
|
||||||
|
|
||||||
|
|||||||
@@ -17,6 +17,9 @@ from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
|
|||||||
from lerobot.policies.vla_jepa.modeling_vla_jepa import VLAJEPAPolicy
|
from lerobot.policies.vla_jepa.modeling_vla_jepa import VLAJEPAPolicy
|
||||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||||
|
|
||||||
|
pytest.importorskip("transformers")
|
||||||
|
pytest.importorskip("diffusers")
|
||||||
|
|
||||||
pytestmark = pytest.mark.filterwarnings(
|
pytestmark = pytest.mark.filterwarnings(
|
||||||
"ignore:In CPU autocast, but the target dtype is not supported:UserWarning"
|
"ignore:In CPU autocast, but the target dtype is not supported:UserWarning"
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user