fix test imports

This commit is contained in:
Steven Palma
2026-04-11 21:07:53 +02:00
parent c9636bb53f
commit 5940126fb5
9 changed files with 71 additions and 42 deletions
-4
View File
@@ -14,10 +14,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from lerobot.utils.import_utils import require_package
require_package("dynamixel-sdk", extra="dynamixel", import_name="dynamixel_sdk")
from .dynamixel import DriveMode, DynamixelMotorsBus, OperatingMode, TorqueMode from .dynamixel import DriveMode, DynamixelMotorsBus, OperatingMode, TorqueMode
from .tables import * # noqa: F403 — hardware constant tables from .tables import * # noqa: F403 — hardware constant tables
+11 -4
View File
@@ -21,6 +21,9 @@
import logging import logging
from copy import deepcopy from copy import deepcopy
from enum import Enum from enum import Enum
from typing import TYPE_CHECKING
from lerobot.utils.import_utils import is_package_available, require_package
from ..encoding_utils import decode_twos_complement, encode_twos_complement from ..encoding_utils import decode_twos_complement, encode_twos_complement
from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address
@@ -33,6 +36,13 @@ from .tables import (
MODEL_RESOLUTION, MODEL_RESOLUTION,
) )
_dynamixel_sdk_available = is_package_available("dynamixel-sdk", import_name="dynamixel_sdk")
if TYPE_CHECKING or _dynamixel_sdk_available:
import dynamixel_sdk as dxl
else:
dxl = None
PROTOCOL_VERSION = 2.0 PROTOCOL_VERSION = 2.0
DEFAULT_BAUDRATE = 1_000_000 DEFAULT_BAUDRATE = 1_000_000
DEFAULT_TIMEOUT_MS = 1000 DEFAULT_TIMEOUT_MS = 1000
@@ -83,8 +93,6 @@ class TorqueMode(Enum):
def _split_into_byte_chunks(value: int, length: int) -> list[int]: def _split_into_byte_chunks(value: int, length: int) -> list[int]:
import dynamixel_sdk as dxl
if length == 1: if length == 1:
data = [value] data = [value]
elif length == 2: elif length == 2:
@@ -123,9 +131,8 @@ class DynamixelMotorsBus(SerialMotorsBus):
motors: dict[str, Motor], motors: dict[str, Motor],
calibration: dict[str, MotorCalibration] | None = None, calibration: dict[str, MotorCalibration] | None = None,
): ):
require_package("dynamixel-sdk", extra="dynamixel", import_name="dynamixel_sdk")
super().__init__(port, motors, calibration) super().__init__(port, motors, calibration)
import dynamixel_sdk as dxl
self.port_handler = dxl.PortHandler(self.port) self.port_handler = dxl.PortHandler(self.port)
self.packet_handler = dxl.PacketHandler(PROTOCOL_VERSION) self.packet_handler = dxl.PacketHandler(PROTOCOL_VERSION)
self.sync_reader = dxl.GroupSyncRead(self.port_handler, self.packet_handler, 0, 0) self.sync_reader = dxl.GroupSyncRead(self.port_handler, self.packet_handler, 0, 0)
-4
View File
@@ -14,10 +14,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from lerobot.utils.import_utils import require_package
require_package("feetech-servo-sdk", extra="feetech", import_name="scservo_sdk")
from .feetech import DriveMode, FeetechMotorsBus, OperatingMode, TorqueMode from .feetech import DriveMode, FeetechMotorsBus, OperatingMode, TorqueMode
from .tables import * # noqa: F403 — hardware constant tables from .tables import * # noqa: F403 — hardware constant tables
+11 -8
View File
@@ -16,6 +16,9 @@ import logging
from copy import deepcopy from copy import deepcopy
from enum import Enum from enum import Enum
from pprint import pformat from pprint import pformat
from typing import TYPE_CHECKING
from lerobot.utils.import_utils import is_package_available, require_package
from ..encoding_utils import decode_sign_magnitude, encode_sign_magnitude from ..encoding_utils import decode_sign_magnitude, encode_sign_magnitude
from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address
@@ -32,6 +35,13 @@ from .tables import (
SCAN_BAUDRATES, SCAN_BAUDRATES,
) )
_feetech_sdk_available = is_package_available("feetech-servo-sdk", import_name="scservo_sdk")
if TYPE_CHECKING or _feetech_sdk_available:
import scservo_sdk as scs
else:
scs = None
DEFAULT_PROTOCOL_VERSION = 0 DEFAULT_PROTOCOL_VERSION = 0
DEFAULT_BAUDRATE = 1_000_000 DEFAULT_BAUDRATE = 1_000_000
DEFAULT_TIMEOUT_MS = 1000 DEFAULT_TIMEOUT_MS = 1000
@@ -66,8 +76,6 @@ class TorqueMode(Enum):
def _split_into_byte_chunks(value: int, length: int) -> list[int]: def _split_into_byte_chunks(value: int, length: int) -> list[int]:
import scservo_sdk as scs
if length == 1: if length == 1:
data = [value] data = [value]
elif length == 2: elif length == 2:
@@ -119,11 +127,10 @@ class FeetechMotorsBus(SerialMotorsBus):
calibration: dict[str, MotorCalibration] | None = None, calibration: dict[str, MotorCalibration] | None = None,
protocol_version: int = DEFAULT_PROTOCOL_VERSION, protocol_version: int = DEFAULT_PROTOCOL_VERSION,
): ):
require_package("feetech-servo-sdk", extra="feetech", import_name="scservo_sdk")
super().__init__(port, motors, calibration) super().__init__(port, motors, calibration)
self.protocol_version = protocol_version self.protocol_version = protocol_version
self._assert_same_protocol() self._assert_same_protocol()
import scservo_sdk as scs
self.port_handler = scs.PortHandler(self.port) self.port_handler = scs.PortHandler(self.port)
# HACK: monkeypatch # HACK: monkeypatch
self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__( # type: ignore[method-assign] self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__( # type: ignore[method-assign]
@@ -195,8 +202,6 @@ class FeetechMotorsBus(SerialMotorsBus):
raise RuntimeError(f"Motor '{motor}' (model '{model}') was not found. Make sure it is connected.") raise RuntimeError(f"Motor '{motor}' (model '{model}') was not found. Make sure it is connected.")
def _find_single_motor_p1(self, motor: str, initial_baudrate: int | None = None) -> tuple[int, int]: def _find_single_motor_p1(self, motor: str, initial_baudrate: int | None = None) -> tuple[int, int]:
import scservo_sdk as scs
model = self.motors[motor].model model = self.motors[motor].model
search_baudrates = ( search_baudrates = (
[initial_baudrate] if initial_baudrate is not None else self.model_baudrate_table[model] [initial_baudrate] if initial_baudrate is not None else self.model_baudrate_table[model]
@@ -332,8 +337,6 @@ class FeetechMotorsBus(SerialMotorsBus):
return _split_into_byte_chunks(value, length) return _split_into_byte_chunks(value, length)
def _broadcast_ping(self) -> tuple[dict[int, int], int]: def _broadcast_ping(self) -> tuple[dict[int, int], int]:
import scservo_sdk as scs
data_list: dict[int, int] = {} data_list: dict[int, int] = {}
status_length = 6 status_length = 6
@@ -14,26 +14,39 @@
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING
import torch import torch
import torch.nn.functional as F # noqa: N812 import torch.nn.functional as F # noqa: N812
from torch import nn
from lerobot.utils.import_utils import require_package from lerobot.utils.import_utils import is_package_available, require_package
require_package("diffusers", extra="groot") _diffusers_available = is_package_available("diffusers")
from diffusers import ConfigMixin, ModelMixin # noqa: E402 if TYPE_CHECKING or _diffusers_available:
from diffusers.configuration_utils import register_to_config # noqa: E402 from diffusers import ConfigMixin, ModelMixin
from diffusers.models.attention import Attention, FeedForward # noqa: E402 from diffusers.configuration_utils import register_to_config
from diffusers.models.embeddings import ( # noqa: E402 from diffusers.models.attention import Attention, FeedForward
SinusoidalPositionalEmbedding, from diffusers.models.embeddings import (
TimestepEmbedding, SinusoidalPositionalEmbedding,
Timesteps, TimestepEmbedding,
) Timesteps,
from torch import nn # noqa: E402 )
else:
ConfigMixin = object
ModelMixin = nn.Module
register_to_config = lambda fn: fn # noqa: E731
Attention = None
FeedForward = None
SinusoidalPositionalEmbedding = None
TimestepEmbedding = None
Timesteps = None
class TimestepEncoder(nn.Module): class TimestepEncoder(nn.Module):
def __init__(self, embedding_dim, compute_dtype=torch.float32): def __init__(self, embedding_dim, compute_dtype=torch.float32):
require_package("diffusers", extra="groot")
super().__init__() super().__init__()
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1) self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
@@ -93,6 +106,7 @@ class BasicTransformerBlock(nn.Module):
ff_bias: bool = True, ff_bias: bool = True,
attention_out_bias: bool = True, attention_out_bias: bool = True,
): ):
require_package("diffusers", extra="groot")
super().__init__() super().__init__()
self.dim = dim self.dim = dim
self.num_attention_heads = num_attention_heads self.num_attention_heads = num_attention_heads
+7 -3
View File
@@ -16,10 +16,14 @@ import math
import pickle import pickle
import time import time
import numpy as np import pytest
import torch
from lerobot.async_inference.helpers import ( pytest.importorskip("grpc")
import numpy as np # noqa: E402
import torch # noqa: E402
from lerobot.async_inference.helpers import ( # noqa: E402
FPSTracker, FPSTracker,
TimedAction, TimedAction,
TimedObservation, TimedObservation,
+6 -2
View File
@@ -18,9 +18,13 @@ import threading
import time import time
from queue import Queue from queue import Queue
from torch.multiprocessing import Queue as TorchMPQueue import pytest
from lerobot.rl.queue import get_last_item_from_queue pytest.importorskip("grpc")
from torch.multiprocessing import Queue as TorchMPQueue # noqa: E402
from lerobot.rl.queue import get_last_item_from_queue # noqa: E402
def test_get_last_item_single_item(): def test_get_last_item_single_item():
+3 -1
View File
@@ -22,7 +22,9 @@ from unittest.mock import patch
import pytest import pytest
from lerobot.rl.process import ProcessSignalHandler pytest.importorskip("grpc")
from lerobot.rl.process import ProcessSignalHandler # noqa: E402
# Fixture to reset shutdown_event_counter and original signal handlers before and after each test # Fixture to reset shutdown_event_counter and original signal handlers before and after each test
+8 -5
View File
@@ -18,12 +18,15 @@ import sys
from collections.abc import Callable from collections.abc import Callable
import pytest import pytest
import torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset pytest.importorskip("grpc")
from lerobot.rl.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_STATE, OBS_STR, REWARD import torch # noqa: E402
from tests.fixtures.constants import DUMMY_REPO_ID
from lerobot.datasets.lerobot_dataset import LeRobotDataset # noqa: E402
from lerobot.rl.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized # noqa: E402
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_STATE, OBS_STR, REWARD # noqa: E402
from tests.fixtures.constants import DUMMY_REPO_ID # noqa: E402
def state_dims() -> list[str]: def state_dims() -> list[str]: