diff --git a/src/lerobot/motors/dynamixel/__init__.py b/src/lerobot/motors/dynamixel/__init__.py index b3ceec192..01fcadf4f 100644 --- a/src/lerobot/motors/dynamixel/__init__.py +++ b/src/lerobot/motors/dynamixel/__init__.py @@ -14,10 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from lerobot.utils.import_utils import require_package - -require_package("dynamixel-sdk", extra="dynamixel", import_name="dynamixel_sdk") - from .dynamixel import DriveMode, DynamixelMotorsBus, OperatingMode, TorqueMode from .tables import * # noqa: F403 — hardware constant tables diff --git a/src/lerobot/motors/dynamixel/dynamixel.py b/src/lerobot/motors/dynamixel/dynamixel.py index bca455dc5..b5aa3ee11 100644 --- a/src/lerobot/motors/dynamixel/dynamixel.py +++ b/src/lerobot/motors/dynamixel/dynamixel.py @@ -21,6 +21,9 @@ import logging from copy import deepcopy from enum import Enum +from typing import TYPE_CHECKING + +from lerobot.utils.import_utils import is_package_available, require_package from ..encoding_utils import decode_twos_complement, encode_twos_complement from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address @@ -33,6 +36,13 @@ from .tables import ( MODEL_RESOLUTION, ) +_dynamixel_sdk_available = is_package_available("dynamixel-sdk", import_name="dynamixel_sdk") + +if TYPE_CHECKING or _dynamixel_sdk_available: + import dynamixel_sdk as dxl +else: + dxl = None + PROTOCOL_VERSION = 2.0 DEFAULT_BAUDRATE = 1_000_000 DEFAULT_TIMEOUT_MS = 1000 @@ -83,8 +93,6 @@ class TorqueMode(Enum): def _split_into_byte_chunks(value: int, length: int) -> list[int]: - import dynamixel_sdk as dxl - if length == 1: data = [value] elif length == 2: @@ -123,9 +131,8 @@ class DynamixelMotorsBus(SerialMotorsBus): motors: dict[str, Motor], calibration: dict[str, MotorCalibration] | None = None, ): + require_package("dynamixel-sdk", extra="dynamixel", import_name="dynamixel_sdk") super().__init__(port, motors, calibration) - import dynamixel_sdk as dxl - self.port_handler = dxl.PortHandler(self.port) self.packet_handler = dxl.PacketHandler(PROTOCOL_VERSION) self.sync_reader = dxl.GroupSyncRead(self.port_handler, self.packet_handler, 0, 0) diff --git a/src/lerobot/motors/feetech/__init__.py b/src/lerobot/motors/feetech/__init__.py index fbea347fa..6c06d8b95 100644 --- a/src/lerobot/motors/feetech/__init__.py +++ b/src/lerobot/motors/feetech/__init__.py @@ -14,10 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from lerobot.utils.import_utils import require_package - -require_package("feetech-servo-sdk", extra="feetech", import_name="scservo_sdk") - from .feetech import DriveMode, FeetechMotorsBus, OperatingMode, TorqueMode from .tables import * # noqa: F403 — hardware constant tables diff --git a/src/lerobot/motors/feetech/feetech.py b/src/lerobot/motors/feetech/feetech.py index 58a65310d..629c0877e 100644 --- a/src/lerobot/motors/feetech/feetech.py +++ b/src/lerobot/motors/feetech/feetech.py @@ -16,6 +16,9 @@ import logging from copy import deepcopy from enum import Enum from pprint import pformat +from typing import TYPE_CHECKING + +from lerobot.utils.import_utils import is_package_available, require_package from ..encoding_utils import decode_sign_magnitude, encode_sign_magnitude from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address @@ -32,6 +35,13 @@ from .tables import ( SCAN_BAUDRATES, ) +_feetech_sdk_available = is_package_available("feetech-servo-sdk", import_name="scservo_sdk") + +if TYPE_CHECKING or _feetech_sdk_available: + import scservo_sdk as scs +else: + scs = None + DEFAULT_PROTOCOL_VERSION = 0 DEFAULT_BAUDRATE = 1_000_000 DEFAULT_TIMEOUT_MS = 1000 @@ -66,8 +76,6 @@ class TorqueMode(Enum): def _split_into_byte_chunks(value: int, length: int) -> list[int]: - import scservo_sdk as scs - if length == 1: data = [value] elif length == 2: @@ -119,11 +127,10 @@ class FeetechMotorsBus(SerialMotorsBus): calibration: dict[str, MotorCalibration] | None = None, protocol_version: int = DEFAULT_PROTOCOL_VERSION, ): + require_package("feetech-servo-sdk", extra="feetech", import_name="scservo_sdk") super().__init__(port, motors, calibration) self.protocol_version = protocol_version self._assert_same_protocol() - import scservo_sdk as scs - self.port_handler = scs.PortHandler(self.port) # HACK: monkeypatch self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__( # type: ignore[method-assign] @@ -195,8 +202,6 @@ class FeetechMotorsBus(SerialMotorsBus): raise RuntimeError(f"Motor '{motor}' (model '{model}') was not found. Make sure it is connected.") def _find_single_motor_p1(self, motor: str, initial_baudrate: int | None = None) -> tuple[int, int]: - import scservo_sdk as scs - model = self.motors[motor].model search_baudrates = ( [initial_baudrate] if initial_baudrate is not None else self.model_baudrate_table[model] @@ -332,8 +337,6 @@ class FeetechMotorsBus(SerialMotorsBus): return _split_into_byte_chunks(value, length) def _broadcast_ping(self) -> tuple[dict[int, int], int]: - import scservo_sdk as scs - data_list: dict[int, int] = {} status_length = 6 diff --git a/src/lerobot/policies/groot/action_head/cross_attention_dit.py b/src/lerobot/policies/groot/action_head/cross_attention_dit.py index ca40bbe78..32e9f64fe 100755 --- a/src/lerobot/policies/groot/action_head/cross_attention_dit.py +++ b/src/lerobot/policies/groot/action_head/cross_attention_dit.py @@ -14,26 +14,39 @@ # limitations under the License. +from typing import TYPE_CHECKING + import torch import torch.nn.functional as F # noqa: N812 +from torch import nn -from lerobot.utils.import_utils import require_package +from lerobot.utils.import_utils import is_package_available, require_package -require_package("diffusers", extra="groot") +_diffusers_available = is_package_available("diffusers") -from diffusers import ConfigMixin, ModelMixin # noqa: E402 -from diffusers.configuration_utils import register_to_config # noqa: E402 -from diffusers.models.attention import Attention, FeedForward # noqa: E402 -from diffusers.models.embeddings import ( # noqa: E402 - SinusoidalPositionalEmbedding, - TimestepEmbedding, - Timesteps, -) -from torch import nn # noqa: E402 +if TYPE_CHECKING or _diffusers_available: + from diffusers import ConfigMixin, ModelMixin + from diffusers.configuration_utils import register_to_config + from diffusers.models.attention import Attention, FeedForward + from diffusers.models.embeddings import ( + SinusoidalPositionalEmbedding, + TimestepEmbedding, + Timesteps, + ) +else: + ConfigMixin = object + ModelMixin = nn.Module + register_to_config = lambda fn: fn # noqa: E731 + Attention = None + FeedForward = None + SinusoidalPositionalEmbedding = None + TimestepEmbedding = None + Timesteps = None class TimestepEncoder(nn.Module): def __init__(self, embedding_dim, compute_dtype=torch.float32): + require_package("diffusers", extra="groot") super().__init__() self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1) self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) @@ -93,6 +106,7 @@ class BasicTransformerBlock(nn.Module): ff_bias: bool = True, attention_out_bias: bool = True, ): + require_package("diffusers", extra="groot") super().__init__() self.dim = dim self.num_attention_heads = num_attention_heads diff --git a/tests/async_inference/test_helpers.py b/tests/async_inference/test_helpers.py index a9e53200d..17fca2a44 100644 --- a/tests/async_inference/test_helpers.py +++ b/tests/async_inference/test_helpers.py @@ -16,10 +16,14 @@ import math import pickle import time -import numpy as np -import torch +import pytest -from lerobot.async_inference.helpers import ( +pytest.importorskip("grpc") + +import numpy as np # noqa: E402 +import torch # noqa: E402 + +from lerobot.async_inference.helpers import ( # noqa: E402 FPSTracker, TimedAction, TimedObservation, diff --git a/tests/rl/test_queue.py b/tests/rl/test_queue.py index b6716fbd6..cf3d6cdca 100644 --- a/tests/rl/test_queue.py +++ b/tests/rl/test_queue.py @@ -18,9 +18,13 @@ import threading import time from queue import Queue -from torch.multiprocessing import Queue as TorchMPQueue +import pytest -from lerobot.rl.queue import get_last_item_from_queue +pytest.importorskip("grpc") + +from torch.multiprocessing import Queue as TorchMPQueue # noqa: E402 + +from lerobot.rl.queue import get_last_item_from_queue # noqa: E402 def test_get_last_item_single_item(): diff --git a/tests/utils/test_process.py b/tests/utils/test_process.py index e2b00cae9..ce56db173 100644 --- a/tests/utils/test_process.py +++ b/tests/utils/test_process.py @@ -22,7 +22,9 @@ from unittest.mock import patch import pytest -from lerobot.rl.process import ProcessSignalHandler +pytest.importorskip("grpc") + +from lerobot.rl.process import ProcessSignalHandler # noqa: E402 # Fixture to reset shutdown_event_counter and original signal handlers before and after each test diff --git a/tests/utils/test_replay_buffer.py b/tests/utils/test_replay_buffer.py index b9d3a1ac0..6085fb7fb 100644 --- a/tests/utils/test_replay_buffer.py +++ b/tests/utils/test_replay_buffer.py @@ -18,12 +18,15 @@ import sys from collections.abc import Callable import pytest -import torch -from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.rl.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized -from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_STATE, OBS_STR, REWARD -from tests.fixtures.constants import DUMMY_REPO_ID +pytest.importorskip("grpc") + +import torch # noqa: E402 + +from lerobot.datasets.lerobot_dataset import LeRobotDataset # noqa: E402 +from lerobot.rl.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized # noqa: E402 +from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_STATE, OBS_STR, REWARD # noqa: E402 +from tests.fixtures.constants import DUMMY_REPO_ID # noqa: E402 def state_dims() -> list[str]: