mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 22:20:06 +00:00
chore(linter): ensure motors module passes MyPy type checks (#2939)
* fix: ensure motors module passes MyPy type checks This commit fixes 62 mypy type errors in the motors module by: - Updating Protocol classes (PortHandler, PacketHandler, GroupSyncRead, GroupSyncWrite) to use class-level attribute declarations instead of __init__ body declarations - Adding missing `broadcastPing` method to PacketHandler Protocol - Fixing return type annotations (e.g., `_get_motor_model` returns str, not int) - Fixing parameter types to use `Sequence` for covariant list parameters - Fixing `Mapping` for covariant dict value types in `_normalize` - Updating method signatures to be consistent across parent and child classes (disable_torque, enable_torque, _get_half_turn_homings) - Adding explicit `int()` casts for MotorCalibration arguments - Adding explicit `return None` for functions returning Optional types - Adding type annotations for variables like `data_list: dict[int, int]` - Using `# type: ignore[method-assign]` for intentional monkeypatch - Fixing variable references (using `self.groups` instead of `groups`) Fixes #1723 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * chore(style): pre-commit after main merge * chore(linter): solve comments * chore(linter): apply pre-commit fixes to damiao * chore(linter): more fixes to damiao --------- Co-authored-by: yurekami <yurekami@users.noreply.github.com> Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
+3
-3
@@ -360,9 +360,9 @@ ignore_errors = false
|
||||
module = "lerobot.cameras.*"
|
||||
ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.motors.*"
|
||||
# ignore_errors = false
|
||||
[[tool.mypy.overrides]]
|
||||
module = "lerobot.motors.*"
|
||||
ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.robots.*"
|
||||
|
||||
@@ -221,7 +221,7 @@ class RangeFinderGUI:
|
||||
|
||||
self.bus = bus
|
||||
self.groups = groups if groups is not None else {"all": list(bus.motors)}
|
||||
self.group_names = list(groups)
|
||||
self.group_names = list(self.groups)
|
||||
self.current_group = self.group_names[0]
|
||||
|
||||
if not bus.is_connected:
|
||||
@@ -230,18 +230,20 @@ class RangeFinderGUI:
|
||||
self.calibration = bus.read_calibration()
|
||||
self.res_table = bus.model_resolution_table
|
||||
self.present_cache = {
|
||||
m: bus.read("Present_Position", m, normalize=False) for motors in groups.values() for m in motors
|
||||
m: bus.read("Present_Position", m, normalize=False)
|
||||
for motors in self.groups.values()
|
||||
for m in motors
|
||||
}
|
||||
|
||||
pygame.init()
|
||||
self.font = pygame.font.Font(None, FONT_SIZE)
|
||||
|
||||
label_pad = max(self.font.size(m)[0] for ms in groups.values() for m in ms)
|
||||
label_pad = max(self.font.size(m)[0] for ms in self.groups.values() for m in ms)
|
||||
self.label_pad = label_pad
|
||||
width = 40 + label_pad + BAR_LEN + 6 + BTN_W + 10 + SAVE_W + 10
|
||||
self.controls_bottom = 10 + SAVE_H
|
||||
self.base_y = self.controls_bottom + TOP_GAP
|
||||
height = self.base_y + PADDING_Y * len(groups[self.current_group]) + 40
|
||||
height = self.base_y + PADDING_Y * len(self.groups[self.current_group]) + 40
|
||||
|
||||
self.screen = pygame.display.set_mode((width, height))
|
||||
pygame.display.set_caption("Motors range finder")
|
||||
|
||||
@@ -211,6 +211,9 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
logger.info("Starting handshake with motors...")
|
||||
|
||||
# Drain any pending messages
|
||||
if self.canbus is None:
|
||||
raise RuntimeError("CAN bus is not initialized.")
|
||||
|
||||
while self.canbus.recv(timeout=0.01):
|
||||
pass
|
||||
|
||||
@@ -283,6 +286,10 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
data = [0xFF] * 7 + [command_byte]
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
|
||||
|
||||
if self.canbus is None:
|
||||
raise RuntimeError("CAN bus is not initialized.")
|
||||
|
||||
self.canbus.send(msg)
|
||||
if msg := self._recv_motor_response(expected_recv_id=recv_id):
|
||||
self._process_response(motor_name, msg)
|
||||
@@ -341,6 +348,10 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0]
|
||||
msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False, is_fd=self.use_can_fd)
|
||||
|
||||
if self.canbus is None:
|
||||
raise RuntimeError("CAN bus is not initialized.")
|
||||
|
||||
self.canbus.send(msg)
|
||||
return self._recv_motor_response(expected_recv_id=recv_id)
|
||||
|
||||
@@ -356,6 +367,10 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
Returns:
|
||||
CAN message if received, None otherwise
|
||||
"""
|
||||
|
||||
if self.canbus is None:
|
||||
raise RuntimeError("CAN bus is not initialized.")
|
||||
|
||||
try:
|
||||
start_time = time.time()
|
||||
messages_seen = []
|
||||
@@ -394,10 +409,13 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
Returns:
|
||||
Dictionary mapping recv_id to CAN message
|
||||
"""
|
||||
responses = {}
|
||||
responses: dict[int, can.Message] = {}
|
||||
expected_set = set(expected_recv_ids)
|
||||
start_time = time.time()
|
||||
|
||||
if self.canbus is None:
|
||||
raise RuntimeError("CAN bus is not initialized.")
|
||||
|
||||
try:
|
||||
while len(responses) < len(expected_recv_ids) and (time.time() - start_time) < timeout:
|
||||
# 100us poll timeout
|
||||
@@ -461,6 +479,9 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
motor_name = self._get_motor_name(motor)
|
||||
motor_type = self._motor_types[motor_name]
|
||||
|
||||
if self.canbus is None:
|
||||
raise RuntimeError("CAN bus is not initialized.")
|
||||
|
||||
data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque)
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
|
||||
self.canbus.send(msg)
|
||||
@@ -488,6 +509,9 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
|
||||
recv_id_to_motor: dict[int, str] = {}
|
||||
|
||||
if self.canbus is None:
|
||||
raise RuntimeError("CAN bus is not initialized.")
|
||||
|
||||
# Step 1: Send all MIT control commands
|
||||
for motor, (kp, kd, position_degrees, velocity_deg_per_sec, torque) in commands.items():
|
||||
motor_id = self._get_motor_id(motor)
|
||||
@@ -656,6 +680,10 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
|
||||
def _batch_refresh(self, motors: list[str]) -> None:
|
||||
"""Internal helper to refresh a list of motors and update cache."""
|
||||
|
||||
if self.canbus is None:
|
||||
raise RuntimeError("CAN bus is not initialized.")
|
||||
|
||||
# Send refresh commands
|
||||
for motor in motors:
|
||||
motor_id = self._get_motor_id(motor)
|
||||
@@ -678,10 +706,14 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
else:
|
||||
logger.warning(f"Packet drop: {motor} (ID: 0x{recv_id:02X}). Using last known state.")
|
||||
|
||||
def sync_write(self, data_name: str, values: Value | dict[str, Value]) -> None:
|
||||
def sync_write(self, data_name: str, values: dict[str, Value]) -> None:
|
||||
"""
|
||||
Write values to multiple motors simultaneously. Positions are always in degrees.
|
||||
"""
|
||||
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if data_name in ("Kp", "Kd"):
|
||||
key = data_name.lower()
|
||||
for motor, val in values.items():
|
||||
@@ -690,6 +722,8 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
elif data_name == "Goal_Position":
|
||||
# Step 1: Send all MIT control commands
|
||||
recv_id_to_motor: dict[int, str] = {}
|
||||
if self.canbus is None:
|
||||
raise RuntimeError("CAN bus is not initialized.")
|
||||
for motor, value_degrees in values.items():
|
||||
motor_id = self._get_motor_id(motor)
|
||||
motor_name = self._get_motor_name(motor)
|
||||
@@ -732,9 +766,9 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
|
||||
def record_ranges_of_motion(
|
||||
self,
|
||||
motors: NameOrID | list[NameOrID] | None = None,
|
||||
motors: str | list[str] | None = None,
|
||||
display_values: bool = True,
|
||||
) -> tuple[dict[NameOrID, Value], dict[NameOrID, Value]]:
|
||||
) -> tuple[dict[str, Value], dict[str, Value]]:
|
||||
"""
|
||||
Interactively record the min/max values of each motor in degrees.
|
||||
|
||||
|
||||
@@ -181,10 +181,10 @@ class DynamixelMotorsBus(SerialMotorsBus):
|
||||
for motor, m in self.motors.items():
|
||||
calibration[motor] = MotorCalibration(
|
||||
id=m.id,
|
||||
drive_mode=drive_modes[motor],
|
||||
homing_offset=offsets[motor],
|
||||
range_min=mins[motor],
|
||||
range_max=maxes[motor],
|
||||
drive_mode=int(drive_modes[motor]),
|
||||
homing_offset=int(offsets[motor]),
|
||||
range_min=int(mins[motor]),
|
||||
range_max=int(maxes[motor]),
|
||||
)
|
||||
|
||||
return calibration
|
||||
@@ -198,7 +198,7 @@ class DynamixelMotorsBus(SerialMotorsBus):
|
||||
if cache:
|
||||
self.calibration = calibration_dict
|
||||
|
||||
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
for motor in self._get_motors_list(motors):
|
||||
self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry)
|
||||
|
||||
@@ -206,7 +206,7 @@ class DynamixelMotorsBus(SerialMotorsBus):
|
||||
addr, length = get_address(self.model_ctrl_table, model, "Torque_Enable")
|
||||
self._write(addr, length, motor, TorqueMode.DISABLED.value, num_retry=num_retry)
|
||||
|
||||
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
for motor in self._get_motors_list(motors):
|
||||
self.write("Torque_Enable", motor, TorqueMode.ENABLED.value, num_retry=num_retry)
|
||||
|
||||
@@ -235,7 +235,7 @@ class DynamixelMotorsBus(SerialMotorsBus):
|
||||
On Dynamixel Motors:
|
||||
Present_Position = Actual_Position + Homing_Offset
|
||||
"""
|
||||
half_turn_homings = {}
|
||||
half_turn_homings: dict[NameOrID, Value] = {}
|
||||
for motor, pos in positions.items():
|
||||
model = self._get_motor_model(motor)
|
||||
max_res = self.model_resolution_table[model] - 1
|
||||
@@ -258,6 +258,6 @@ class DynamixelMotorsBus(SerialMotorsBus):
|
||||
if raise_on_error:
|
||||
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
|
||||
|
||||
return
|
||||
return None
|
||||
|
||||
return {id_: data[0] for id_, data in data_list.items()}
|
||||
|
||||
@@ -126,7 +126,7 @@ class FeetechMotorsBus(SerialMotorsBus):
|
||||
|
||||
self.port_handler = scs.PortHandler(self.port)
|
||||
# HACK: monkeypatch
|
||||
self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__(
|
||||
self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__( # type: ignore[method-assign]
|
||||
self.port_handler, scs.PortHandler
|
||||
)
|
||||
self.packet_handler = scs.PacketHandler(protocol_version)
|
||||
@@ -262,9 +262,9 @@ class FeetechMotorsBus(SerialMotorsBus):
|
||||
calibration[motor] = MotorCalibration(
|
||||
id=m.id,
|
||||
drive_mode=0,
|
||||
homing_offset=offsets[motor],
|
||||
range_min=mins[motor],
|
||||
range_max=maxes[motor],
|
||||
homing_offset=int(offsets[motor]),
|
||||
range_min=int(mins[motor]),
|
||||
range_max=int(maxes[motor]),
|
||||
)
|
||||
|
||||
return calibration
|
||||
@@ -284,7 +284,7 @@ class FeetechMotorsBus(SerialMotorsBus):
|
||||
On Feetech Motors:
|
||||
Present_Position = Actual_Position - Homing_Offset
|
||||
"""
|
||||
half_turn_homings = {}
|
||||
half_turn_homings: dict[NameOrID, Value] = {}
|
||||
for motor, pos in positions.items():
|
||||
model = self._get_motor_model(motor)
|
||||
max_res = self.model_resolution_table[model] - 1
|
||||
@@ -292,7 +292,7 @@ class FeetechMotorsBus(SerialMotorsBus):
|
||||
|
||||
return half_turn_homings
|
||||
|
||||
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
for motor in self._get_motors_list(motors):
|
||||
self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry)
|
||||
self.write("Lock", motor, 0, num_retry=num_retry)
|
||||
@@ -303,7 +303,7 @@ class FeetechMotorsBus(SerialMotorsBus):
|
||||
addr, length = get_address(self.model_ctrl_table, model, "Lock")
|
||||
self._write(addr, length, motor, 0, num_retry=num_retry)
|
||||
|
||||
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
for motor in self._get_motors_list(motors):
|
||||
self.write("Torque_Enable", motor, TorqueMode.ENABLED.value, num_retry=num_retry)
|
||||
self.write("Lock", motor, 1, num_retry=num_retry)
|
||||
@@ -334,7 +334,7 @@ class FeetechMotorsBus(SerialMotorsBus):
|
||||
def _broadcast_ping(self) -> tuple[dict[int, int], int]:
|
||||
import scservo_sdk as scs
|
||||
|
||||
data_list = {}
|
||||
data_list: dict[int, int] = {}
|
||||
|
||||
status_length = 6
|
||||
|
||||
@@ -414,7 +414,7 @@ class FeetechMotorsBus(SerialMotorsBus):
|
||||
if not self._is_comm_success(comm):
|
||||
if raise_on_error:
|
||||
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
|
||||
return
|
||||
return None
|
||||
|
||||
ids_errors = {id_: status for id_, status in ids_status.items() if self._is_error(status)}
|
||||
if ids_errors:
|
||||
|
||||
@@ -23,6 +23,7 @@ from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
@@ -93,7 +94,7 @@ class MotorsBusBase(abc.ABC):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def sync_write(self, data_name: str, values: Value | dict[str, Value]) -> None:
|
||||
def sync_write(self, data_name: str, values: dict[str, Value]) -> None:
|
||||
"""Write values to multiple motors."""
|
||||
pass
|
||||
|
||||
@@ -179,15 +180,16 @@ class Motor:
|
||||
|
||||
|
||||
class PortHandler(Protocol):
|
||||
def __init__(self, port_name):
|
||||
self.is_open: bool
|
||||
self.baudrate: int
|
||||
self.packet_start_time: float
|
||||
self.packet_timeout: float
|
||||
self.tx_time_per_byte: float
|
||||
self.is_using: bool
|
||||
self.port_name: str
|
||||
self.ser: serial.Serial
|
||||
is_open: bool
|
||||
baudrate: int
|
||||
packet_start_time: float
|
||||
packet_timeout: float
|
||||
tx_time_per_byte: float
|
||||
is_using: bool
|
||||
port_name: str
|
||||
ser: serial.Serial
|
||||
|
||||
def __init__(self, port_name: str) -> None: ...
|
||||
|
||||
def openPort(self): ...
|
||||
def closePort(self): ...
|
||||
@@ -240,19 +242,22 @@ class PacketHandler(Protocol):
|
||||
def regWriteTxRx(self, port, id, address, length, data): ...
|
||||
def syncReadTx(self, port, start_address, data_length, param, param_length): ...
|
||||
def syncWriteTxOnly(self, port, start_address, data_length, param, param_length): ...
|
||||
def broadcastPing(self, port): ...
|
||||
|
||||
|
||||
class GroupSyncRead(Protocol):
|
||||
def __init__(self, port, ph, start_address, data_length):
|
||||
self.port: str
|
||||
self.ph: PortHandler
|
||||
self.start_address: int
|
||||
self.data_length: int
|
||||
self.last_result: bool
|
||||
self.is_param_changed: bool
|
||||
self.param: list
|
||||
self.data_dict: dict
|
||||
port: str
|
||||
ph: PortHandler
|
||||
start_address: int
|
||||
data_length: int
|
||||
last_result: bool
|
||||
is_param_changed: bool
|
||||
param: list
|
||||
data_dict: dict
|
||||
|
||||
def __init__(
|
||||
self, port: PortHandler, ph: PacketHandler, start_address: int, data_length: int
|
||||
) -> None: ...
|
||||
def makeParam(self): ...
|
||||
def addParam(self, id): ...
|
||||
def removeParam(self, id): ...
|
||||
@@ -265,15 +270,17 @@ class GroupSyncRead(Protocol):
|
||||
|
||||
|
||||
class GroupSyncWrite(Protocol):
|
||||
def __init__(self, port, ph, start_address, data_length):
|
||||
self.port: str
|
||||
self.ph: PortHandler
|
||||
self.start_address: int
|
||||
self.data_length: int
|
||||
self.is_param_changed: bool
|
||||
self.param: list
|
||||
self.data_dict: dict
|
||||
port: str
|
||||
ph: PortHandler
|
||||
start_address: int
|
||||
data_length: int
|
||||
is_param_changed: bool
|
||||
param: list
|
||||
data_dict: dict
|
||||
|
||||
def __init__(
|
||||
self, port: PortHandler, ph: PacketHandler, start_address: int, data_length: int
|
||||
) -> None: ...
|
||||
def makeParam(self): ...
|
||||
def addParam(self, id, data): ...
|
||||
def removeParam(self, id): ...
|
||||
@@ -400,7 +407,7 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
else:
|
||||
raise TypeError(f"'{motor}' should be int, str.")
|
||||
|
||||
def _get_motor_model(self, motor: NameOrID) -> int:
|
||||
def _get_motor_model(self, motor: NameOrID) -> str:
|
||||
if isinstance(motor, str):
|
||||
return self.motors[motor].model
|
||||
elif isinstance(motor, int):
|
||||
@@ -408,17 +415,19 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
else:
|
||||
raise TypeError(f"'{motor}' should be int, str.")
|
||||
|
||||
def _get_motors_list(self, motors: str | list[str] | None) -> list[str]:
|
||||
def _get_motors_list(self, motors: NameOrID | Sequence[NameOrID] | None) -> list[str]:
|
||||
if motors is None:
|
||||
return list(self.motors)
|
||||
elif isinstance(motors, str):
|
||||
return [motors]
|
||||
elif isinstance(motors, list):
|
||||
return motors.copy()
|
||||
elif isinstance(motors, int):
|
||||
return [self._id_to_name(motors)]
|
||||
elif isinstance(motors, Sequence):
|
||||
return [m if isinstance(m, str) else self._id_to_name(m) for m in motors]
|
||||
else:
|
||||
raise TypeError(motors)
|
||||
|
||||
def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> list[str]:
|
||||
def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> dict[int, Value]:
|
||||
if isinstance(values, (int | float)):
|
||||
return dict.fromkeys(self.ids, values)
|
||||
elif isinstance(values, dict):
|
||||
@@ -640,18 +649,19 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
"""Enable torque on selected motors.
|
||||
|
||||
Args:
|
||||
motor (int): Same semantics as :pymeth:`disable_torque`. Defaults to `None`.
|
||||
motors (int | str | list[str] | None, optional): Same semantics as :pymeth:`disable_torque`.
|
||||
Defaults to `None`.
|
||||
num_retry (int, optional): Number of additional retry attempts on communication failure.
|
||||
Defaults to 0.
|
||||
"""
|
||||
pass
|
||||
|
||||
@contextmanager
|
||||
def torque_disabled(self, motors: int | str | list[str] | None = None):
|
||||
def torque_disabled(self, motors: str | list[str] | None = None):
|
||||
"""Context-manager that guarantees torque is re-enabled.
|
||||
|
||||
This helper is useful to temporarily disable torque when configuring motors.
|
||||
@@ -728,24 +738,19 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
"""
|
||||
pass
|
||||
|
||||
def reset_calibration(self, motors: NameOrID | list[NameOrID] | None = None) -> None:
|
||||
def reset_calibration(self, motors: NameOrID | Sequence[NameOrID] | None = None) -> None:
|
||||
"""Restore factory calibration for the selected motors.
|
||||
|
||||
Homing offset is set to ``0`` and min/max position limits are set to the full usable range.
|
||||
The in-memory :pyattr:`calibration` is cleared.
|
||||
|
||||
Args:
|
||||
motors (NameOrID | list[NameOrID] | None, optional): Selection of motors. `None` (default)
|
||||
motors (NameOrID | Sequence[NameOrID] | None, optional): Selection of motors. `None` (default)
|
||||
resets every motor.
|
||||
"""
|
||||
if motors is None:
|
||||
motors = list(self.motors)
|
||||
elif isinstance(motors, (str | int)):
|
||||
motors = [motors]
|
||||
elif not isinstance(motors, list):
|
||||
raise TypeError(motors)
|
||||
motor_names = self._get_motors_list(motors)
|
||||
|
||||
for motor in motors:
|
||||
for motor in motor_names:
|
||||
model = self._get_motor_model(motor)
|
||||
max_res = self.model_resolution_table[model] - 1
|
||||
self.write("Homing_Offset", motor, 0, normalize=False)
|
||||
@@ -754,7 +759,9 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
|
||||
self.calibration = {}
|
||||
|
||||
def set_half_turn_homings(self, motors: NameOrID | list[NameOrID] | None = None) -> dict[NameOrID, Value]:
|
||||
def set_half_turn_homings(
|
||||
self, motors: NameOrID | Sequence[NameOrID] | None = None
|
||||
) -> dict[NameOrID, Value]:
|
||||
"""Centre each motor range around its current position.
|
||||
|
||||
The function computes and writes a homing offset such that the present position becomes exactly one
|
||||
@@ -764,17 +771,12 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
motors (NameOrID | list[NameOrID] | None, optional): Motors to adjust. Defaults to all motors (`None`).
|
||||
|
||||
Returns:
|
||||
dict[NameOrID, Value]: Mapping *motor → written homing offset*.
|
||||
dict[str, Value]: Mapping *motor name → written homing offset*.
|
||||
"""
|
||||
if motors is None:
|
||||
motors = list(self.motors)
|
||||
elif isinstance(motors, (str | int)):
|
||||
motors = [motors]
|
||||
elif not isinstance(motors, list):
|
||||
raise TypeError(motors)
|
||||
motor_names = self._get_motors_list(motors)
|
||||
|
||||
self.reset_calibration(motors)
|
||||
actual_positions = self.sync_read("Present_Position", motors, normalize=False)
|
||||
self.reset_calibration(motor_names)
|
||||
actual_positions = self.sync_read("Present_Position", motor_names, normalize=False)
|
||||
homing_offsets = self._get_half_turn_homings(actual_positions)
|
||||
for motor, offset in homing_offsets.items():
|
||||
self.write("Homing_Offset", motor, offset)
|
||||
@@ -786,8 +788,8 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
pass
|
||||
|
||||
def record_ranges_of_motion(
|
||||
self, motors: NameOrID | list[NameOrID] | None = None, display_values: bool = True
|
||||
) -> tuple[dict[NameOrID, Value], dict[NameOrID, Value]]:
|
||||
self, motors: NameOrID | Sequence[NameOrID] | None = None, display_values: bool = True
|
||||
) -> tuple[dict[str, Value], dict[str, Value]]:
|
||||
"""Interactively record the min/max encoder values of each motor.
|
||||
|
||||
Move the joints by hand (with torque disabled) while the method streams live positions. Press
|
||||
@@ -799,30 +801,25 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
display_values (bool, optional): When `True` (default) a live table is printed to the console.
|
||||
|
||||
Returns:
|
||||
tuple[dict[NameOrID, Value], dict[NameOrID, Value]]: Two dictionaries *mins* and *maxes* with the
|
||||
tuple[dict[str, Value], dict[str, Value]]: Two dictionaries *mins* and *maxes* with the
|
||||
extreme values observed for each motor.
|
||||
"""
|
||||
if motors is None:
|
||||
motors = list(self.motors)
|
||||
elif isinstance(motors, (str | int)):
|
||||
motors = [motors]
|
||||
elif not isinstance(motors, list):
|
||||
raise TypeError(motors)
|
||||
motor_names = self._get_motors_list(motors)
|
||||
|
||||
start_positions = self.sync_read("Present_Position", motors, normalize=False)
|
||||
start_positions = self.sync_read("Present_Position", motor_names, normalize=False)
|
||||
mins = start_positions.copy()
|
||||
maxes = start_positions.copy()
|
||||
|
||||
user_pressed_enter = False
|
||||
while not user_pressed_enter:
|
||||
positions = self.sync_read("Present_Position", motors, normalize=False)
|
||||
positions = self.sync_read("Present_Position", motor_names, normalize=False)
|
||||
mins = {motor: min(positions[motor], min_) for motor, min_ in mins.items()}
|
||||
maxes = {motor: max(positions[motor], max_) for motor, max_ in maxes.items()}
|
||||
|
||||
if display_values:
|
||||
print("\n-------------------------------------------")
|
||||
print(f"{'NAME':<15} | {'MIN':>6} | {'POS':>6} | {'MAX':>6}")
|
||||
for motor in motors:
|
||||
for motor in motor_names:
|
||||
print(f"{motor:<15} | {mins[motor]:>6} | {positions[motor]:>6} | {maxes[motor]:>6}")
|
||||
|
||||
if enter_pressed():
|
||||
@@ -830,9 +827,9 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
|
||||
if display_values and not user_pressed_enter:
|
||||
# Move cursor up to overwrite the previous output
|
||||
move_cursor_up(len(motors) + 3)
|
||||
move_cursor_up(len(motor_names) + 3)
|
||||
|
||||
same_min_max = [motor for motor in motors if mins[motor] == maxes[motor]]
|
||||
same_min_max = [motor for motor in motor_names if mins[motor] == maxes[motor]]
|
||||
if same_min_max:
|
||||
raise ValueError(f"Some motors have the same min and max values:\n{pformat(same_min_max)}")
|
||||
|
||||
@@ -955,12 +952,12 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
if raise_on_error:
|
||||
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
|
||||
else:
|
||||
return
|
||||
return None
|
||||
if self._is_error(error):
|
||||
if raise_on_error:
|
||||
raise RuntimeError(self.packet_handler.getRxPacketError(error))
|
||||
else:
|
||||
return
|
||||
return None
|
||||
|
||||
return model_number
|
||||
|
||||
@@ -1007,12 +1004,13 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
err_msg = f"Failed to read '{data_name}' on {id_=} after {num_retry + 1} tries."
|
||||
value, _, _ = self._read(addr, length, id_, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
|
||||
|
||||
id_value = self._decode_sign(data_name, {id_: value})
|
||||
decoded = self._decode_sign(data_name, {id_: value})
|
||||
|
||||
if normalize and data_name in self.normalized_data:
|
||||
id_value = self._normalize(id_value)
|
||||
normalized = self._normalize(decoded)
|
||||
return normalized[id_]
|
||||
|
||||
return id_value[id_]
|
||||
return decoded[id_]
|
||||
|
||||
def _read(
|
||||
self,
|
||||
@@ -1023,7 +1021,7 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
num_retry: int = 0,
|
||||
raise_on_error: bool = True,
|
||||
err_msg: str = "",
|
||||
) -> tuple[int, int]:
|
||||
) -> tuple[int, int, int]:
|
||||
if length == 1:
|
||||
read_fn = self.packet_handler.read1ByteTxRx
|
||||
elif length == 2:
|
||||
@@ -1073,13 +1071,14 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
model = self.motors[motor].model
|
||||
addr, length = get_address(self.model_ctrl_table, model, data_name)
|
||||
|
||||
int_value = int(value)
|
||||
if normalize and data_name in self.normalized_data:
|
||||
value = self._unnormalize({id_: value})[id_]
|
||||
int_value = self._unnormalize({id_: value})[id_]
|
||||
|
||||
value = self._encode_sign(data_name, {id_: value})[id_]
|
||||
int_value = self._encode_sign(data_name, {id_: int_value})[id_]
|
||||
|
||||
err_msg = f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries."
|
||||
self._write(addr, length, id_, value, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
|
||||
err_msg = f"Failed to write '{data_name}' on {id_=} with '{int_value}' after {num_retry + 1} tries."
|
||||
self._write(addr, length, id_, int_value, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
|
||||
|
||||
def _write(
|
||||
self,
|
||||
@@ -1113,7 +1112,7 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
def sync_read(
|
||||
self,
|
||||
data_name: str,
|
||||
motors: str | list[str] | None = None,
|
||||
motors: NameOrID | Sequence[NameOrID] | None = None,
|
||||
*,
|
||||
normalize: bool = True,
|
||||
num_retry: int = 0,
|
||||
@@ -1122,7 +1121,7 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
|
||||
Args:
|
||||
data_name (str): Register name.
|
||||
motors (str | list[str] | None, optional): Motors to query. `None` (default) reads every motor.
|
||||
motors (NameOrID | Sequence[NameOrID] | None, optional): Motors to query. `None` (default) reads every motor.
|
||||
normalize (bool, optional): Normalisation flag. Defaults to `True`.
|
||||
num_retry (int, optional): Retry attempts. Defaults to `0`.
|
||||
|
||||
@@ -1143,16 +1142,17 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
addr, length = get_address(self.model_ctrl_table, model, data_name)
|
||||
|
||||
err_msg = f"Failed to sync read '{data_name}' on {ids=} after {num_retry + 1} tries."
|
||||
ids_values, _ = self._sync_read(
|
||||
raw_ids_values, _ = self._sync_read(
|
||||
addr, length, ids, num_retry=num_retry, raise_on_error=True, err_msg=err_msg
|
||||
)
|
||||
|
||||
ids_values = self._decode_sign(data_name, ids_values)
|
||||
decoded = self._decode_sign(data_name, raw_ids_values)
|
||||
|
||||
if normalize and data_name in self.normalized_data:
|
||||
ids_values = self._normalize(ids_values)
|
||||
normalized = self._normalize(decoded)
|
||||
return {self._id_to_name(id_): value for id_, value in normalized.items()}
|
||||
|
||||
return {self._id_to_name(id_): value for id_, value in ids_values.items()}
|
||||
return {self._id_to_name(id_): value for id_, value in decoded.items()}
|
||||
|
||||
def _sync_read(
|
||||
self,
|
||||
@@ -1224,21 +1224,24 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
num_retry (int, optional): Retry attempts. Defaults to `0`.
|
||||
"""
|
||||
|
||||
ids_values = self._get_ids_values_dict(values)
|
||||
models = [self._id_to_model(id_) for id_ in ids_values]
|
||||
raw_ids_values = self._get_ids_values_dict(values)
|
||||
models = [self._id_to_model(id_) for id_ in raw_ids_values]
|
||||
if self._has_different_ctrl_tables:
|
||||
assert_same_address(self.model_ctrl_table, models, data_name)
|
||||
|
||||
model = next(iter(models))
|
||||
addr, length = get_address(self.model_ctrl_table, model, data_name)
|
||||
|
||||
int_ids_values = {id_: int(val) for id_, val in raw_ids_values.items()}
|
||||
if normalize and data_name in self.normalized_data:
|
||||
ids_values = self._unnormalize(ids_values)
|
||||
int_ids_values = self._unnormalize(raw_ids_values)
|
||||
|
||||
ids_values = self._encode_sign(data_name, ids_values)
|
||||
int_ids_values = self._encode_sign(data_name, int_ids_values)
|
||||
|
||||
err_msg = f"Failed to sync write '{data_name}' with {ids_values=} after {num_retry + 1} tries."
|
||||
self._sync_write(addr, length, ids_values, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
|
||||
err_msg = f"Failed to sync write '{data_name}' with ids_values={int_ids_values} after {num_retry + 1} tries."
|
||||
self._sync_write(
|
||||
addr, length, int_ids_values, num_retry=num_retry, raise_on_error=True, err_msg=err_msg
|
||||
)
|
||||
|
||||
def _sync_write(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user