mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 20:19:43 +00:00
fix test imports
This commit is contained in:
@@ -14,10 +14,6 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from lerobot.utils.import_utils import require_package
|
|
||||||
|
|
||||||
require_package("dynamixel-sdk", extra="dynamixel", import_name="dynamixel_sdk")
|
|
||||||
|
|
||||||
from .dynamixel import DriveMode, DynamixelMotorsBus, OperatingMode, TorqueMode
|
from .dynamixel import DriveMode, DynamixelMotorsBus, OperatingMode, TorqueMode
|
||||||
from .tables import * # noqa: F403 — hardware constant tables
|
from .tables import * # noqa: F403 — hardware constant tables
|
||||||
|
|
||||||
|
|||||||
@@ -21,6 +21,9 @@
|
|||||||
import logging
|
import logging
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from lerobot.utils.import_utils import is_package_available, require_package
|
||||||
|
|
||||||
from ..encoding_utils import decode_twos_complement, encode_twos_complement
|
from ..encoding_utils import decode_twos_complement, encode_twos_complement
|
||||||
from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address
|
from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address
|
||||||
@@ -33,6 +36,13 @@ from .tables import (
|
|||||||
MODEL_RESOLUTION,
|
MODEL_RESOLUTION,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_dynamixel_sdk_available = is_package_available("dynamixel-sdk", import_name="dynamixel_sdk")
|
||||||
|
|
||||||
|
if TYPE_CHECKING or _dynamixel_sdk_available:
|
||||||
|
import dynamixel_sdk as dxl
|
||||||
|
else:
|
||||||
|
dxl = None
|
||||||
|
|
||||||
PROTOCOL_VERSION = 2.0
|
PROTOCOL_VERSION = 2.0
|
||||||
DEFAULT_BAUDRATE = 1_000_000
|
DEFAULT_BAUDRATE = 1_000_000
|
||||||
DEFAULT_TIMEOUT_MS = 1000
|
DEFAULT_TIMEOUT_MS = 1000
|
||||||
@@ -83,8 +93,6 @@ class TorqueMode(Enum):
|
|||||||
|
|
||||||
|
|
||||||
def _split_into_byte_chunks(value: int, length: int) -> list[int]:
|
def _split_into_byte_chunks(value: int, length: int) -> list[int]:
|
||||||
import dynamixel_sdk as dxl
|
|
||||||
|
|
||||||
if length == 1:
|
if length == 1:
|
||||||
data = [value]
|
data = [value]
|
||||||
elif length == 2:
|
elif length == 2:
|
||||||
@@ -123,9 +131,8 @@ class DynamixelMotorsBus(SerialMotorsBus):
|
|||||||
motors: dict[str, Motor],
|
motors: dict[str, Motor],
|
||||||
calibration: dict[str, MotorCalibration] | None = None,
|
calibration: dict[str, MotorCalibration] | None = None,
|
||||||
):
|
):
|
||||||
|
require_package("dynamixel-sdk", extra="dynamixel", import_name="dynamixel_sdk")
|
||||||
super().__init__(port, motors, calibration)
|
super().__init__(port, motors, calibration)
|
||||||
import dynamixel_sdk as dxl
|
|
||||||
|
|
||||||
self.port_handler = dxl.PortHandler(self.port)
|
self.port_handler = dxl.PortHandler(self.port)
|
||||||
self.packet_handler = dxl.PacketHandler(PROTOCOL_VERSION)
|
self.packet_handler = dxl.PacketHandler(PROTOCOL_VERSION)
|
||||||
self.sync_reader = dxl.GroupSyncRead(self.port_handler, self.packet_handler, 0, 0)
|
self.sync_reader = dxl.GroupSyncRead(self.port_handler, self.packet_handler, 0, 0)
|
||||||
|
|||||||
@@ -14,10 +14,6 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from lerobot.utils.import_utils import require_package
|
|
||||||
|
|
||||||
require_package("feetech-servo-sdk", extra="feetech", import_name="scservo_sdk")
|
|
||||||
|
|
||||||
from .feetech import DriveMode, FeetechMotorsBus, OperatingMode, TorqueMode
|
from .feetech import DriveMode, FeetechMotorsBus, OperatingMode, TorqueMode
|
||||||
from .tables import * # noqa: F403 — hardware constant tables
|
from .tables import * # noqa: F403 — hardware constant tables
|
||||||
|
|
||||||
|
|||||||
@@ -16,6 +16,9 @@ import logging
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from lerobot.utils.import_utils import is_package_available, require_package
|
||||||
|
|
||||||
from ..encoding_utils import decode_sign_magnitude, encode_sign_magnitude
|
from ..encoding_utils import decode_sign_magnitude, encode_sign_magnitude
|
||||||
from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address
|
from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address
|
||||||
@@ -32,6 +35,13 @@ from .tables import (
|
|||||||
SCAN_BAUDRATES,
|
SCAN_BAUDRATES,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_feetech_sdk_available = is_package_available("feetech-servo-sdk", import_name="scservo_sdk")
|
||||||
|
|
||||||
|
if TYPE_CHECKING or _feetech_sdk_available:
|
||||||
|
import scservo_sdk as scs
|
||||||
|
else:
|
||||||
|
scs = None
|
||||||
|
|
||||||
DEFAULT_PROTOCOL_VERSION = 0
|
DEFAULT_PROTOCOL_VERSION = 0
|
||||||
DEFAULT_BAUDRATE = 1_000_000
|
DEFAULT_BAUDRATE = 1_000_000
|
||||||
DEFAULT_TIMEOUT_MS = 1000
|
DEFAULT_TIMEOUT_MS = 1000
|
||||||
@@ -66,8 +76,6 @@ class TorqueMode(Enum):
|
|||||||
|
|
||||||
|
|
||||||
def _split_into_byte_chunks(value: int, length: int) -> list[int]:
|
def _split_into_byte_chunks(value: int, length: int) -> list[int]:
|
||||||
import scservo_sdk as scs
|
|
||||||
|
|
||||||
if length == 1:
|
if length == 1:
|
||||||
data = [value]
|
data = [value]
|
||||||
elif length == 2:
|
elif length == 2:
|
||||||
@@ -119,11 +127,10 @@ class FeetechMotorsBus(SerialMotorsBus):
|
|||||||
calibration: dict[str, MotorCalibration] | None = None,
|
calibration: dict[str, MotorCalibration] | None = None,
|
||||||
protocol_version: int = DEFAULT_PROTOCOL_VERSION,
|
protocol_version: int = DEFAULT_PROTOCOL_VERSION,
|
||||||
):
|
):
|
||||||
|
require_package("feetech-servo-sdk", extra="feetech", import_name="scservo_sdk")
|
||||||
super().__init__(port, motors, calibration)
|
super().__init__(port, motors, calibration)
|
||||||
self.protocol_version = protocol_version
|
self.protocol_version = protocol_version
|
||||||
self._assert_same_protocol()
|
self._assert_same_protocol()
|
||||||
import scservo_sdk as scs
|
|
||||||
|
|
||||||
self.port_handler = scs.PortHandler(self.port)
|
self.port_handler = scs.PortHandler(self.port)
|
||||||
# HACK: monkeypatch
|
# HACK: monkeypatch
|
||||||
self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__( # type: ignore[method-assign]
|
self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__( # type: ignore[method-assign]
|
||||||
@@ -195,8 +202,6 @@ class FeetechMotorsBus(SerialMotorsBus):
|
|||||||
raise RuntimeError(f"Motor '{motor}' (model '{model}') was not found. Make sure it is connected.")
|
raise RuntimeError(f"Motor '{motor}' (model '{model}') was not found. Make sure it is connected.")
|
||||||
|
|
||||||
def _find_single_motor_p1(self, motor: str, initial_baudrate: int | None = None) -> tuple[int, int]:
|
def _find_single_motor_p1(self, motor: str, initial_baudrate: int | None = None) -> tuple[int, int]:
|
||||||
import scservo_sdk as scs
|
|
||||||
|
|
||||||
model = self.motors[motor].model
|
model = self.motors[motor].model
|
||||||
search_baudrates = (
|
search_baudrates = (
|
||||||
[initial_baudrate] if initial_baudrate is not None else self.model_baudrate_table[model]
|
[initial_baudrate] if initial_baudrate is not None else self.model_baudrate_table[model]
|
||||||
@@ -332,8 +337,6 @@ class FeetechMotorsBus(SerialMotorsBus):
|
|||||||
return _split_into_byte_chunks(value, length)
|
return _split_into_byte_chunks(value, length)
|
||||||
|
|
||||||
def _broadcast_ping(self) -> tuple[dict[int, int], int]:
|
def _broadcast_ping(self) -> tuple[dict[int, int], int]:
|
||||||
import scservo_sdk as scs
|
|
||||||
|
|
||||||
data_list: dict[int, int] = {}
|
data_list: dict[int, int] = {}
|
||||||
|
|
||||||
status_length = 6
|
status_length = 6
|
||||||
|
|||||||
@@ -14,26 +14,39 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F # noqa: N812
|
import torch.nn.functional as F # noqa: N812
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
from lerobot.utils.import_utils import require_package
|
from lerobot.utils.import_utils import is_package_available, require_package
|
||||||
|
|
||||||
require_package("diffusers", extra="groot")
|
_diffusers_available = is_package_available("diffusers")
|
||||||
|
|
||||||
from diffusers import ConfigMixin, ModelMixin # noqa: E402
|
if TYPE_CHECKING or _diffusers_available:
|
||||||
from diffusers.configuration_utils import register_to_config # noqa: E402
|
from diffusers import ConfigMixin, ModelMixin
|
||||||
from diffusers.models.attention import Attention, FeedForward # noqa: E402
|
from diffusers.configuration_utils import register_to_config
|
||||||
from diffusers.models.embeddings import ( # noqa: E402
|
from diffusers.models.attention import Attention, FeedForward
|
||||||
SinusoidalPositionalEmbedding,
|
from diffusers.models.embeddings import (
|
||||||
TimestepEmbedding,
|
SinusoidalPositionalEmbedding,
|
||||||
Timesteps,
|
TimestepEmbedding,
|
||||||
)
|
Timesteps,
|
||||||
from torch import nn # noqa: E402
|
)
|
||||||
|
else:
|
||||||
|
ConfigMixin = object
|
||||||
|
ModelMixin = nn.Module
|
||||||
|
register_to_config = lambda fn: fn # noqa: E731
|
||||||
|
Attention = None
|
||||||
|
FeedForward = None
|
||||||
|
SinusoidalPositionalEmbedding = None
|
||||||
|
TimestepEmbedding = None
|
||||||
|
Timesteps = None
|
||||||
|
|
||||||
|
|
||||||
class TimestepEncoder(nn.Module):
|
class TimestepEncoder(nn.Module):
|
||||||
def __init__(self, embedding_dim, compute_dtype=torch.float32):
|
def __init__(self, embedding_dim, compute_dtype=torch.float32):
|
||||||
|
require_package("diffusers", extra="groot")
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
|
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
|
||||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||||
@@ -93,6 +106,7 @@ class BasicTransformerBlock(nn.Module):
|
|||||||
ff_bias: bool = True,
|
ff_bias: bool = True,
|
||||||
attention_out_bias: bool = True,
|
attention_out_bias: bool = True,
|
||||||
):
|
):
|
||||||
|
require_package("diffusers", extra="groot")
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
self.num_attention_heads = num_attention_heads
|
self.num_attention_heads = num_attention_heads
|
||||||
|
|||||||
@@ -16,10 +16,14 @@ import math
|
|||||||
import pickle
|
import pickle
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import numpy as np
|
import pytest
|
||||||
import torch
|
|
||||||
|
|
||||||
from lerobot.async_inference.helpers import (
|
pytest.importorskip("grpc")
|
||||||
|
|
||||||
|
import numpy as np # noqa: E402
|
||||||
|
import torch # noqa: E402
|
||||||
|
|
||||||
|
from lerobot.async_inference.helpers import ( # noqa: E402
|
||||||
FPSTracker,
|
FPSTracker,
|
||||||
TimedAction,
|
TimedAction,
|
||||||
TimedObservation,
|
TimedObservation,
|
||||||
|
|||||||
@@ -18,9 +18,13 @@ import threading
|
|||||||
import time
|
import time
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
|
|
||||||
from torch.multiprocessing import Queue as TorchMPQueue
|
import pytest
|
||||||
|
|
||||||
from lerobot.rl.queue import get_last_item_from_queue
|
pytest.importorskip("grpc")
|
||||||
|
|
||||||
|
from torch.multiprocessing import Queue as TorchMPQueue # noqa: E402
|
||||||
|
|
||||||
|
from lerobot.rl.queue import get_last_item_from_queue # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
def test_get_last_item_single_item():
|
def test_get_last_item_single_item():
|
||||||
|
|||||||
@@ -22,7 +22,9 @@ from unittest.mock import patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from lerobot.rl.process import ProcessSignalHandler
|
pytest.importorskip("grpc")
|
||||||
|
|
||||||
|
from lerobot.rl.process import ProcessSignalHandler # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
# Fixture to reset shutdown_event_counter and original signal handlers before and after each test
|
# Fixture to reset shutdown_event_counter and original signal handlers before and after each test
|
||||||
|
|||||||
@@ -18,12 +18,15 @@ import sys
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
|
||||||
|
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
pytest.importorskip("grpc")
|
||||||
from lerobot.rl.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized
|
|
||||||
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_STATE, OBS_STR, REWARD
|
import torch # noqa: E402
|
||||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
|
||||||
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset # noqa: E402
|
||||||
|
from lerobot.rl.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized # noqa: E402
|
||||||
|
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_STATE, OBS_STR, REWARD # noqa: E402
|
||||||
|
from tests.fixtures.constants import DUMMY_REPO_ID # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
def state_dims() -> list[str]:
|
def state_dims() -> list[str]:
|
||||||
|
|||||||
Reference in New Issue
Block a user