chore(linter): ensure motors module passes MyPy type checks (#2939)

* fix: ensure motors module passes MyPy type checks

This commit fixes 62 mypy type errors in the motors module by:

- Updating Protocol classes (PortHandler, PacketHandler, GroupSyncRead,
  GroupSyncWrite) to use class-level attribute declarations instead of
  __init__ body declarations
- Adding missing `broadcastPing` method to PacketHandler Protocol
- Fixing return type annotations (e.g., `_get_motor_model` returns str, not int)
- Fixing parameter types to use `Sequence` for covariant list parameters
- Fixing `Mapping` for covariant dict value types in `_normalize`
- Updating method signatures to be consistent across parent and child classes
  (disable_torque, enable_torque, _get_half_turn_homings)
- Adding explicit `int()` casts for MotorCalibration arguments
- Adding explicit `return None` for functions returning Optional types
- Adding type annotations for variables like `data_list: dict[int, int]`
- Using `# type: ignore[method-assign]` for intentional monkeypatch
- Fixing variable references (using `self.groups` instead of `groups`)

Fixes #1723

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* chore(style): pre-commit after main merge

* chore(linter): solve comments

* chore(linter): apply pre-commit fixes to damiao

* chore(linter): more fixes to damiao

---------

Co-authored-by: yurekami <yurekami@users.noreply.github.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Steven Palma
2026-02-10 17:35:39 +01:00
committed by GitHub
parent 778db19a17
commit 35363c5798
6 changed files with 157 additions and 118 deletions
+3 -3
View File
@@ -360,9 +360,9 @@ ignore_errors = false
module = "lerobot.cameras.*" module = "lerobot.cameras.*"
ignore_errors = false ignore_errors = false
# [[tool.mypy.overrides]] [[tool.mypy.overrides]]
# module = "lerobot.motors.*" module = "lerobot.motors.*"
# ignore_errors = false ignore_errors = false
# [[tool.mypy.overrides]] # [[tool.mypy.overrides]]
# module = "lerobot.robots.*" # module = "lerobot.robots.*"
+6 -4
View File
@@ -221,7 +221,7 @@ class RangeFinderGUI:
self.bus = bus self.bus = bus
self.groups = groups if groups is not None else {"all": list(bus.motors)} self.groups = groups if groups is not None else {"all": list(bus.motors)}
self.group_names = list(groups) self.group_names = list(self.groups)
self.current_group = self.group_names[0] self.current_group = self.group_names[0]
if not bus.is_connected: if not bus.is_connected:
@@ -230,18 +230,20 @@ class RangeFinderGUI:
self.calibration = bus.read_calibration() self.calibration = bus.read_calibration()
self.res_table = bus.model_resolution_table self.res_table = bus.model_resolution_table
self.present_cache = { self.present_cache = {
m: bus.read("Present_Position", m, normalize=False) for motors in groups.values() for m in motors m: bus.read("Present_Position", m, normalize=False)
for motors in self.groups.values()
for m in motors
} }
pygame.init() pygame.init()
self.font = pygame.font.Font(None, FONT_SIZE) self.font = pygame.font.Font(None, FONT_SIZE)
label_pad = max(self.font.size(m)[0] for ms in groups.values() for m in ms) label_pad = max(self.font.size(m)[0] for ms in self.groups.values() for m in ms)
self.label_pad = label_pad self.label_pad = label_pad
width = 40 + label_pad + BAR_LEN + 6 + BTN_W + 10 + SAVE_W + 10 width = 40 + label_pad + BAR_LEN + 6 + BTN_W + 10 + SAVE_W + 10
self.controls_bottom = 10 + SAVE_H self.controls_bottom = 10 + SAVE_H
self.base_y = self.controls_bottom + TOP_GAP self.base_y = self.controls_bottom + TOP_GAP
height = self.base_y + PADDING_Y * len(groups[self.current_group]) + 40 height = self.base_y + PADDING_Y * len(self.groups[self.current_group]) + 40
self.screen = pygame.display.set_mode((width, height)) self.screen = pygame.display.set_mode((width, height))
pygame.display.set_caption("Motors range finder") pygame.display.set_caption("Motors range finder")
+38 -4
View File
@@ -211,6 +211,9 @@ class DamiaoMotorsBus(MotorsBusBase):
logger.info("Starting handshake with motors...") logger.info("Starting handshake with motors...")
# Drain any pending messages # Drain any pending messages
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
while self.canbus.recv(timeout=0.01): while self.canbus.recv(timeout=0.01):
pass pass
@@ -283,6 +286,10 @@ class DamiaoMotorsBus(MotorsBusBase):
recv_id = self._get_motor_recv_id(motor) recv_id = self._get_motor_recv_id(motor)
data = [0xFF] * 7 + [command_byte] data = [0xFF] * 7 + [command_byte]
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd) msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
self.canbus.send(msg) self.canbus.send(msg)
if msg := self._recv_motor_response(expected_recv_id=recv_id): if msg := self._recv_motor_response(expected_recv_id=recv_id):
self._process_response(motor_name, msg) self._process_response(motor_name, msg)
@@ -341,6 +348,10 @@ class DamiaoMotorsBus(MotorsBusBase):
recv_id = self._get_motor_recv_id(motor) recv_id = self._get_motor_recv_id(motor)
data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0] data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0]
msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False, is_fd=self.use_can_fd) msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False, is_fd=self.use_can_fd)
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
self.canbus.send(msg) self.canbus.send(msg)
return self._recv_motor_response(expected_recv_id=recv_id) return self._recv_motor_response(expected_recv_id=recv_id)
@@ -356,6 +367,10 @@ class DamiaoMotorsBus(MotorsBusBase):
Returns: Returns:
CAN message if received, None otherwise CAN message if received, None otherwise
""" """
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
try: try:
start_time = time.time() start_time = time.time()
messages_seen = [] messages_seen = []
@@ -394,10 +409,13 @@ class DamiaoMotorsBus(MotorsBusBase):
Returns: Returns:
Dictionary mapping recv_id to CAN message Dictionary mapping recv_id to CAN message
""" """
responses = {} responses: dict[int, can.Message] = {}
expected_set = set(expected_recv_ids) expected_set = set(expected_recv_ids)
start_time = time.time() start_time = time.time()
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
try: try:
while len(responses) < len(expected_recv_ids) and (time.time() - start_time) < timeout: while len(responses) < len(expected_recv_ids) and (time.time() - start_time) < timeout:
# 100us poll timeout # 100us poll timeout
@@ -461,6 +479,9 @@ class DamiaoMotorsBus(MotorsBusBase):
motor_name = self._get_motor_name(motor) motor_name = self._get_motor_name(motor)
motor_type = self._motor_types[motor_name] motor_type = self._motor_types[motor_name]
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque) data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque)
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd) msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
self.canbus.send(msg) self.canbus.send(msg)
@@ -488,6 +509,9 @@ class DamiaoMotorsBus(MotorsBusBase):
recv_id_to_motor: dict[int, str] = {} recv_id_to_motor: dict[int, str] = {}
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
# Step 1: Send all MIT control commands # Step 1: Send all MIT control commands
for motor, (kp, kd, position_degrees, velocity_deg_per_sec, torque) in commands.items(): for motor, (kp, kd, position_degrees, velocity_deg_per_sec, torque) in commands.items():
motor_id = self._get_motor_id(motor) motor_id = self._get_motor_id(motor)
@@ -656,6 +680,10 @@ class DamiaoMotorsBus(MotorsBusBase):
def _batch_refresh(self, motors: list[str]) -> None: def _batch_refresh(self, motors: list[str]) -> None:
"""Internal helper to refresh a list of motors and update cache.""" """Internal helper to refresh a list of motors and update cache."""
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
# Send refresh commands # Send refresh commands
for motor in motors: for motor in motors:
motor_id = self._get_motor_id(motor) motor_id = self._get_motor_id(motor)
@@ -678,10 +706,14 @@ class DamiaoMotorsBus(MotorsBusBase):
else: else:
logger.warning(f"Packet drop: {motor} (ID: 0x{recv_id:02X}). Using last known state.") logger.warning(f"Packet drop: {motor} (ID: 0x{recv_id:02X}). Using last known state.")
def sync_write(self, data_name: str, values: Value | dict[str, Value]) -> None: def sync_write(self, data_name: str, values: dict[str, Value]) -> None:
""" """
Write values to multiple motors simultaneously. Positions are always in degrees. Write values to multiple motors simultaneously. Positions are always in degrees.
""" """
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if data_name in ("Kp", "Kd"): if data_name in ("Kp", "Kd"):
key = data_name.lower() key = data_name.lower()
for motor, val in values.items(): for motor, val in values.items():
@@ -690,6 +722,8 @@ class DamiaoMotorsBus(MotorsBusBase):
elif data_name == "Goal_Position": elif data_name == "Goal_Position":
# Step 1: Send all MIT control commands # Step 1: Send all MIT control commands
recv_id_to_motor: dict[int, str] = {} recv_id_to_motor: dict[int, str] = {}
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
for motor, value_degrees in values.items(): for motor, value_degrees in values.items():
motor_id = self._get_motor_id(motor) motor_id = self._get_motor_id(motor)
motor_name = self._get_motor_name(motor) motor_name = self._get_motor_name(motor)
@@ -732,9 +766,9 @@ class DamiaoMotorsBus(MotorsBusBase):
def record_ranges_of_motion( def record_ranges_of_motion(
self, self,
motors: NameOrID | list[NameOrID] | None = None, motors: str | list[str] | None = None,
display_values: bool = True, display_values: bool = True,
) -> tuple[dict[NameOrID, Value], dict[NameOrID, Value]]: ) -> tuple[dict[str, Value], dict[str, Value]]:
""" """
Interactively record the min/max values of each motor in degrees. Interactively record the min/max values of each motor in degrees.
+8 -8
View File
@@ -181,10 +181,10 @@ class DynamixelMotorsBus(SerialMotorsBus):
for motor, m in self.motors.items(): for motor, m in self.motors.items():
calibration[motor] = MotorCalibration( calibration[motor] = MotorCalibration(
id=m.id, id=m.id,
drive_mode=drive_modes[motor], drive_mode=int(drive_modes[motor]),
homing_offset=offsets[motor], homing_offset=int(offsets[motor]),
range_min=mins[motor], range_min=int(mins[motor]),
range_max=maxes[motor], range_max=int(maxes[motor]),
) )
return calibration return calibration
@@ -198,7 +198,7 @@ class DynamixelMotorsBus(SerialMotorsBus):
if cache: if cache:
self.calibration = calibration_dict self.calibration = calibration_dict
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
for motor in self._get_motors_list(motors): for motor in self._get_motors_list(motors):
self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry) self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry)
@@ -206,7 +206,7 @@ class DynamixelMotorsBus(SerialMotorsBus):
addr, length = get_address(self.model_ctrl_table, model, "Torque_Enable") addr, length = get_address(self.model_ctrl_table, model, "Torque_Enable")
self._write(addr, length, motor, TorqueMode.DISABLED.value, num_retry=num_retry) self._write(addr, length, motor, TorqueMode.DISABLED.value, num_retry=num_retry)
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
for motor in self._get_motors_list(motors): for motor in self._get_motors_list(motors):
self.write("Torque_Enable", motor, TorqueMode.ENABLED.value, num_retry=num_retry) self.write("Torque_Enable", motor, TorqueMode.ENABLED.value, num_retry=num_retry)
@@ -235,7 +235,7 @@ class DynamixelMotorsBus(SerialMotorsBus):
On Dynamixel Motors: On Dynamixel Motors:
Present_Position = Actual_Position + Homing_Offset Present_Position = Actual_Position + Homing_Offset
""" """
half_turn_homings = {} half_turn_homings: dict[NameOrID, Value] = {}
for motor, pos in positions.items(): for motor, pos in positions.items():
model = self._get_motor_model(motor) model = self._get_motor_model(motor)
max_res = self.model_resolution_table[model] - 1 max_res = self.model_resolution_table[model] - 1
@@ -258,6 +258,6 @@ class DynamixelMotorsBus(SerialMotorsBus):
if raise_on_error: if raise_on_error:
raise ConnectionError(self.packet_handler.getTxRxResult(comm)) raise ConnectionError(self.packet_handler.getTxRxResult(comm))
return return None
return {id_: data[0] for id_, data in data_list.items()} return {id_: data[0] for id_, data in data_list.items()}
+9 -9
View File
@@ -126,7 +126,7 @@ class FeetechMotorsBus(SerialMotorsBus):
self.port_handler = scs.PortHandler(self.port) self.port_handler = scs.PortHandler(self.port)
# HACK: monkeypatch # HACK: monkeypatch
self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__( self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__( # type: ignore[method-assign]
self.port_handler, scs.PortHandler self.port_handler, scs.PortHandler
) )
self.packet_handler = scs.PacketHandler(protocol_version) self.packet_handler = scs.PacketHandler(protocol_version)
@@ -262,9 +262,9 @@ class FeetechMotorsBus(SerialMotorsBus):
calibration[motor] = MotorCalibration( calibration[motor] = MotorCalibration(
id=m.id, id=m.id,
drive_mode=0, drive_mode=0,
homing_offset=offsets[motor], homing_offset=int(offsets[motor]),
range_min=mins[motor], range_min=int(mins[motor]),
range_max=maxes[motor], range_max=int(maxes[motor]),
) )
return calibration return calibration
@@ -284,7 +284,7 @@ class FeetechMotorsBus(SerialMotorsBus):
On Feetech Motors: On Feetech Motors:
Present_Position = Actual_Position - Homing_Offset Present_Position = Actual_Position - Homing_Offset
""" """
half_turn_homings = {} half_turn_homings: dict[NameOrID, Value] = {}
for motor, pos in positions.items(): for motor, pos in positions.items():
model = self._get_motor_model(motor) model = self._get_motor_model(motor)
max_res = self.model_resolution_table[model] - 1 max_res = self.model_resolution_table[model] - 1
@@ -292,7 +292,7 @@ class FeetechMotorsBus(SerialMotorsBus):
return half_turn_homings return half_turn_homings
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
for motor in self._get_motors_list(motors): for motor in self._get_motors_list(motors):
self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry) self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry)
self.write("Lock", motor, 0, num_retry=num_retry) self.write("Lock", motor, 0, num_retry=num_retry)
@@ -303,7 +303,7 @@ class FeetechMotorsBus(SerialMotorsBus):
addr, length = get_address(self.model_ctrl_table, model, "Lock") addr, length = get_address(self.model_ctrl_table, model, "Lock")
self._write(addr, length, motor, 0, num_retry=num_retry) self._write(addr, length, motor, 0, num_retry=num_retry)
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
for motor in self._get_motors_list(motors): for motor in self._get_motors_list(motors):
self.write("Torque_Enable", motor, TorqueMode.ENABLED.value, num_retry=num_retry) self.write("Torque_Enable", motor, TorqueMode.ENABLED.value, num_retry=num_retry)
self.write("Lock", motor, 1, num_retry=num_retry) self.write("Lock", motor, 1, num_retry=num_retry)
@@ -334,7 +334,7 @@ class FeetechMotorsBus(SerialMotorsBus):
def _broadcast_ping(self) -> tuple[dict[int, int], int]: def _broadcast_ping(self) -> tuple[dict[int, int], int]:
import scservo_sdk as scs import scservo_sdk as scs
data_list = {} data_list: dict[int, int] = {}
status_length = 6 status_length = 6
@@ -414,7 +414,7 @@ class FeetechMotorsBus(SerialMotorsBus):
if not self._is_comm_success(comm): if not self._is_comm_success(comm):
if raise_on_error: if raise_on_error:
raise ConnectionError(self.packet_handler.getTxRxResult(comm)) raise ConnectionError(self.packet_handler.getTxRxResult(comm))
return return None
ids_errors = {id_: status for id_, status in ids_status.items() if self._is_error(status)} ids_errors = {id_: status for id_, status in ids_status.items() if self._is_error(status)}
if ids_errors: if ids_errors:
+93 -90
View File
@@ -23,6 +23,7 @@ from __future__ import annotations
import abc import abc
import logging import logging
from collections.abc import Sequence
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
@@ -93,7 +94,7 @@ class MotorsBusBase(abc.ABC):
pass pass
@abc.abstractmethod @abc.abstractmethod
def sync_write(self, data_name: str, values: Value | dict[str, Value]) -> None: def sync_write(self, data_name: str, values: dict[str, Value]) -> None:
"""Write values to multiple motors.""" """Write values to multiple motors."""
pass pass
@@ -179,15 +180,16 @@ class Motor:
class PortHandler(Protocol): class PortHandler(Protocol):
def __init__(self, port_name): is_open: bool
self.is_open: bool baudrate: int
self.baudrate: int packet_start_time: float
self.packet_start_time: float packet_timeout: float
self.packet_timeout: float tx_time_per_byte: float
self.tx_time_per_byte: float is_using: bool
self.is_using: bool port_name: str
self.port_name: str ser: serial.Serial
self.ser: serial.Serial
def __init__(self, port_name: str) -> None: ...
def openPort(self): ... def openPort(self): ...
def closePort(self): ... def closePort(self): ...
@@ -240,19 +242,22 @@ class PacketHandler(Protocol):
def regWriteTxRx(self, port, id, address, length, data): ... def regWriteTxRx(self, port, id, address, length, data): ...
def syncReadTx(self, port, start_address, data_length, param, param_length): ... def syncReadTx(self, port, start_address, data_length, param, param_length): ...
def syncWriteTxOnly(self, port, start_address, data_length, param, param_length): ... def syncWriteTxOnly(self, port, start_address, data_length, param, param_length): ...
def broadcastPing(self, port): ...
class GroupSyncRead(Protocol): class GroupSyncRead(Protocol):
def __init__(self, port, ph, start_address, data_length): port: str
self.port: str ph: PortHandler
self.ph: PortHandler start_address: int
self.start_address: int data_length: int
self.data_length: int last_result: bool
self.last_result: bool is_param_changed: bool
self.is_param_changed: bool param: list
self.param: list data_dict: dict
self.data_dict: dict
def __init__(
self, port: PortHandler, ph: PacketHandler, start_address: int, data_length: int
) -> None: ...
def makeParam(self): ... def makeParam(self): ...
def addParam(self, id): ... def addParam(self, id): ...
def removeParam(self, id): ... def removeParam(self, id): ...
@@ -265,15 +270,17 @@ class GroupSyncRead(Protocol):
class GroupSyncWrite(Protocol): class GroupSyncWrite(Protocol):
def __init__(self, port, ph, start_address, data_length): port: str
self.port: str ph: PortHandler
self.ph: PortHandler start_address: int
self.start_address: int data_length: int
self.data_length: int is_param_changed: bool
self.is_param_changed: bool param: list
self.param: list data_dict: dict
self.data_dict: dict
def __init__(
self, port: PortHandler, ph: PacketHandler, start_address: int, data_length: int
) -> None: ...
def makeParam(self): ... def makeParam(self): ...
def addParam(self, id, data): ... def addParam(self, id, data): ...
def removeParam(self, id): ... def removeParam(self, id): ...
@@ -400,7 +407,7 @@ class SerialMotorsBus(MotorsBusBase):
else: else:
raise TypeError(f"'{motor}' should be int, str.") raise TypeError(f"'{motor}' should be int, str.")
def _get_motor_model(self, motor: NameOrID) -> int: def _get_motor_model(self, motor: NameOrID) -> str:
if isinstance(motor, str): if isinstance(motor, str):
return self.motors[motor].model return self.motors[motor].model
elif isinstance(motor, int): elif isinstance(motor, int):
@@ -408,17 +415,19 @@ class SerialMotorsBus(MotorsBusBase):
else: else:
raise TypeError(f"'{motor}' should be int, str.") raise TypeError(f"'{motor}' should be int, str.")
def _get_motors_list(self, motors: str | list[str] | None) -> list[str]: def _get_motors_list(self, motors: NameOrID | Sequence[NameOrID] | None) -> list[str]:
if motors is None: if motors is None:
return list(self.motors) return list(self.motors)
elif isinstance(motors, str): elif isinstance(motors, str):
return [motors] return [motors]
elif isinstance(motors, list): elif isinstance(motors, int):
return motors.copy() return [self._id_to_name(motors)]
elif isinstance(motors, Sequence):
return [m if isinstance(m, str) else self._id_to_name(m) for m in motors]
else: else:
raise TypeError(motors) raise TypeError(motors)
def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> list[str]: def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> dict[int, Value]:
if isinstance(values, (int | float)): if isinstance(values, (int | float)):
return dict.fromkeys(self.ids, values) return dict.fromkeys(self.ids, values)
elif isinstance(values, dict): elif isinstance(values, dict):
@@ -640,18 +649,19 @@ class SerialMotorsBus(MotorsBusBase):
pass pass
@abc.abstractmethod @abc.abstractmethod
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
"""Enable torque on selected motors. """Enable torque on selected motors.
Args: Args:
motor (int): Same semantics as :pymeth:`disable_torque`. Defaults to `None`. motors (int | str | list[str] | None, optional): Same semantics as :pymeth:`disable_torque`.
Defaults to `None`.
num_retry (int, optional): Number of additional retry attempts on communication failure. num_retry (int, optional): Number of additional retry attempts on communication failure.
Defaults to 0. Defaults to 0.
""" """
pass pass
@contextmanager @contextmanager
def torque_disabled(self, motors: int | str | list[str] | None = None): def torque_disabled(self, motors: str | list[str] | None = None):
"""Context-manager that guarantees torque is re-enabled. """Context-manager that guarantees torque is re-enabled.
This helper is useful to temporarily disable torque when configuring motors. This helper is useful to temporarily disable torque when configuring motors.
@@ -728,24 +738,19 @@ class SerialMotorsBus(MotorsBusBase):
""" """
pass pass
def reset_calibration(self, motors: NameOrID | list[NameOrID] | None = None) -> None: def reset_calibration(self, motors: NameOrID | Sequence[NameOrID] | None = None) -> None:
"""Restore factory calibration for the selected motors. """Restore factory calibration for the selected motors.
Homing offset is set to ``0`` and min/max position limits are set to the full usable range. Homing offset is set to ``0`` and min/max position limits are set to the full usable range.
The in-memory :pyattr:`calibration` is cleared. The in-memory :pyattr:`calibration` is cleared.
Args: Args:
motors (NameOrID | list[NameOrID] | None, optional): Selection of motors. `None` (default) motors (NameOrID | Sequence[NameOrID] | None, optional): Selection of motors. `None` (default)
resets every motor. resets every motor.
""" """
if motors is None: motor_names = self._get_motors_list(motors)
motors = list(self.motors)
elif isinstance(motors, (str | int)):
motors = [motors]
elif not isinstance(motors, list):
raise TypeError(motors)
for motor in motors: for motor in motor_names:
model = self._get_motor_model(motor) model = self._get_motor_model(motor)
max_res = self.model_resolution_table[model] - 1 max_res = self.model_resolution_table[model] - 1
self.write("Homing_Offset", motor, 0, normalize=False) self.write("Homing_Offset", motor, 0, normalize=False)
@@ -754,7 +759,9 @@ class SerialMotorsBus(MotorsBusBase):
self.calibration = {} self.calibration = {}
def set_half_turn_homings(self, motors: NameOrID | list[NameOrID] | None = None) -> dict[NameOrID, Value]: def set_half_turn_homings(
self, motors: NameOrID | Sequence[NameOrID] | None = None
) -> dict[NameOrID, Value]:
"""Centre each motor range around its current position. """Centre each motor range around its current position.
The function computes and writes a homing offset such that the present position becomes exactly one The function computes and writes a homing offset such that the present position becomes exactly one
@@ -764,17 +771,12 @@ class SerialMotorsBus(MotorsBusBase):
motors (NameOrID | list[NameOrID] | None, optional): Motors to adjust. Defaults to all motors (`None`). motors (NameOrID | list[NameOrID] | None, optional): Motors to adjust. Defaults to all motors (`None`).
Returns: Returns:
dict[NameOrID, Value]: Mapping *motor written homing offset*. dict[str, Value]: Mapping *motor name written homing offset*.
""" """
if motors is None: motor_names = self._get_motors_list(motors)
motors = list(self.motors)
elif isinstance(motors, (str | int)):
motors = [motors]
elif not isinstance(motors, list):
raise TypeError(motors)
self.reset_calibration(motors) self.reset_calibration(motor_names)
actual_positions = self.sync_read("Present_Position", motors, normalize=False) actual_positions = self.sync_read("Present_Position", motor_names, normalize=False)
homing_offsets = self._get_half_turn_homings(actual_positions) homing_offsets = self._get_half_turn_homings(actual_positions)
for motor, offset in homing_offsets.items(): for motor, offset in homing_offsets.items():
self.write("Homing_Offset", motor, offset) self.write("Homing_Offset", motor, offset)
@@ -786,8 +788,8 @@ class SerialMotorsBus(MotorsBusBase):
pass pass
def record_ranges_of_motion( def record_ranges_of_motion(
self, motors: NameOrID | list[NameOrID] | None = None, display_values: bool = True self, motors: NameOrID | Sequence[NameOrID] | None = None, display_values: bool = True
) -> tuple[dict[NameOrID, Value], dict[NameOrID, Value]]: ) -> tuple[dict[str, Value], dict[str, Value]]:
"""Interactively record the min/max encoder values of each motor. """Interactively record the min/max encoder values of each motor.
Move the joints by hand (with torque disabled) while the method streams live positions. Press Move the joints by hand (with torque disabled) while the method streams live positions. Press
@@ -799,30 +801,25 @@ class SerialMotorsBus(MotorsBusBase):
display_values (bool, optional): When `True` (default) a live table is printed to the console. display_values (bool, optional): When `True` (default) a live table is printed to the console.
Returns: Returns:
tuple[dict[NameOrID, Value], dict[NameOrID, Value]]: Two dictionaries *mins* and *maxes* with the tuple[dict[str, Value], dict[str, Value]]: Two dictionaries *mins* and *maxes* with the
extreme values observed for each motor. extreme values observed for each motor.
""" """
if motors is None: motor_names = self._get_motors_list(motors)
motors = list(self.motors)
elif isinstance(motors, (str | int)):
motors = [motors]
elif not isinstance(motors, list):
raise TypeError(motors)
start_positions = self.sync_read("Present_Position", motors, normalize=False) start_positions = self.sync_read("Present_Position", motor_names, normalize=False)
mins = start_positions.copy() mins = start_positions.copy()
maxes = start_positions.copy() maxes = start_positions.copy()
user_pressed_enter = False user_pressed_enter = False
while not user_pressed_enter: while not user_pressed_enter:
positions = self.sync_read("Present_Position", motors, normalize=False) positions = self.sync_read("Present_Position", motor_names, normalize=False)
mins = {motor: min(positions[motor], min_) for motor, min_ in mins.items()} mins = {motor: min(positions[motor], min_) for motor, min_ in mins.items()}
maxes = {motor: max(positions[motor], max_) for motor, max_ in maxes.items()} maxes = {motor: max(positions[motor], max_) for motor, max_ in maxes.items()}
if display_values: if display_values:
print("\n-------------------------------------------") print("\n-------------------------------------------")
print(f"{'NAME':<15} | {'MIN':>6} | {'POS':>6} | {'MAX':>6}") print(f"{'NAME':<15} | {'MIN':>6} | {'POS':>6} | {'MAX':>6}")
for motor in motors: for motor in motor_names:
print(f"{motor:<15} | {mins[motor]:>6} | {positions[motor]:>6} | {maxes[motor]:>6}") print(f"{motor:<15} | {mins[motor]:>6} | {positions[motor]:>6} | {maxes[motor]:>6}")
if enter_pressed(): if enter_pressed():
@@ -830,9 +827,9 @@ class SerialMotorsBus(MotorsBusBase):
if display_values and not user_pressed_enter: if display_values and not user_pressed_enter:
# Move cursor up to overwrite the previous output # Move cursor up to overwrite the previous output
move_cursor_up(len(motors) + 3) move_cursor_up(len(motor_names) + 3)
same_min_max = [motor for motor in motors if mins[motor] == maxes[motor]] same_min_max = [motor for motor in motor_names if mins[motor] == maxes[motor]]
if same_min_max: if same_min_max:
raise ValueError(f"Some motors have the same min and max values:\n{pformat(same_min_max)}") raise ValueError(f"Some motors have the same min and max values:\n{pformat(same_min_max)}")
@@ -955,12 +952,12 @@ class SerialMotorsBus(MotorsBusBase):
if raise_on_error: if raise_on_error:
raise ConnectionError(self.packet_handler.getTxRxResult(comm)) raise ConnectionError(self.packet_handler.getTxRxResult(comm))
else: else:
return return None
if self._is_error(error): if self._is_error(error):
if raise_on_error: if raise_on_error:
raise RuntimeError(self.packet_handler.getRxPacketError(error)) raise RuntimeError(self.packet_handler.getRxPacketError(error))
else: else:
return return None
return model_number return model_number
@@ -1007,12 +1004,13 @@ class SerialMotorsBus(MotorsBusBase):
err_msg = f"Failed to read '{data_name}' on {id_=} after {num_retry + 1} tries." err_msg = f"Failed to read '{data_name}' on {id_=} after {num_retry + 1} tries."
value, _, _ = self._read(addr, length, id_, num_retry=num_retry, raise_on_error=True, err_msg=err_msg) value, _, _ = self._read(addr, length, id_, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
id_value = self._decode_sign(data_name, {id_: value}) decoded = self._decode_sign(data_name, {id_: value})
if normalize and data_name in self.normalized_data: if normalize and data_name in self.normalized_data:
id_value = self._normalize(id_value) normalized = self._normalize(decoded)
return normalized[id_]
return id_value[id_] return decoded[id_]
def _read( def _read(
self, self,
@@ -1023,7 +1021,7 @@ class SerialMotorsBus(MotorsBusBase):
num_retry: int = 0, num_retry: int = 0,
raise_on_error: bool = True, raise_on_error: bool = True,
err_msg: str = "", err_msg: str = "",
) -> tuple[int, int]: ) -> tuple[int, int, int]:
if length == 1: if length == 1:
read_fn = self.packet_handler.read1ByteTxRx read_fn = self.packet_handler.read1ByteTxRx
elif length == 2: elif length == 2:
@@ -1073,13 +1071,14 @@ class SerialMotorsBus(MotorsBusBase):
model = self.motors[motor].model model = self.motors[motor].model
addr, length = get_address(self.model_ctrl_table, model, data_name) addr, length = get_address(self.model_ctrl_table, model, data_name)
int_value = int(value)
if normalize and data_name in self.normalized_data: if normalize and data_name in self.normalized_data:
value = self._unnormalize({id_: value})[id_] int_value = self._unnormalize({id_: value})[id_]
value = self._encode_sign(data_name, {id_: value})[id_] int_value = self._encode_sign(data_name, {id_: int_value})[id_]
err_msg = f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries." err_msg = f"Failed to write '{data_name}' on {id_=} with '{int_value}' after {num_retry + 1} tries."
self._write(addr, length, id_, value, num_retry=num_retry, raise_on_error=True, err_msg=err_msg) self._write(addr, length, id_, int_value, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
def _write( def _write(
self, self,
@@ -1113,7 +1112,7 @@ class SerialMotorsBus(MotorsBusBase):
def sync_read( def sync_read(
self, self,
data_name: str, data_name: str,
motors: str | list[str] | None = None, motors: NameOrID | Sequence[NameOrID] | None = None,
*, *,
normalize: bool = True, normalize: bool = True,
num_retry: int = 0, num_retry: int = 0,
@@ -1122,7 +1121,7 @@ class SerialMotorsBus(MotorsBusBase):
Args: Args:
data_name (str): Register name. data_name (str): Register name.
motors (str | list[str] | None, optional): Motors to query. `None` (default) reads every motor. motors (NameOrID | Sequence[NameOrID] | None, optional): Motors to query. `None` (default) reads every motor.
normalize (bool, optional): Normalisation flag. Defaults to `True`. normalize (bool, optional): Normalisation flag. Defaults to `True`.
num_retry (int, optional): Retry attempts. Defaults to `0`. num_retry (int, optional): Retry attempts. Defaults to `0`.
@@ -1143,16 +1142,17 @@ class SerialMotorsBus(MotorsBusBase):
addr, length = get_address(self.model_ctrl_table, model, data_name) addr, length = get_address(self.model_ctrl_table, model, data_name)
err_msg = f"Failed to sync read '{data_name}' on {ids=} after {num_retry + 1} tries." err_msg = f"Failed to sync read '{data_name}' on {ids=} after {num_retry + 1} tries."
ids_values, _ = self._sync_read( raw_ids_values, _ = self._sync_read(
addr, length, ids, num_retry=num_retry, raise_on_error=True, err_msg=err_msg addr, length, ids, num_retry=num_retry, raise_on_error=True, err_msg=err_msg
) )
ids_values = self._decode_sign(data_name, ids_values) decoded = self._decode_sign(data_name, raw_ids_values)
if normalize and data_name in self.normalized_data: if normalize and data_name in self.normalized_data:
ids_values = self._normalize(ids_values) normalized = self._normalize(decoded)
return {self._id_to_name(id_): value for id_, value in normalized.items()}
return {self._id_to_name(id_): value for id_, value in ids_values.items()} return {self._id_to_name(id_): value for id_, value in decoded.items()}
def _sync_read( def _sync_read(
self, self,
@@ -1224,21 +1224,24 @@ class SerialMotorsBus(MotorsBusBase):
num_retry (int, optional): Retry attempts. Defaults to `0`. num_retry (int, optional): Retry attempts. Defaults to `0`.
""" """
ids_values = self._get_ids_values_dict(values) raw_ids_values = self._get_ids_values_dict(values)
models = [self._id_to_model(id_) for id_ in ids_values] models = [self._id_to_model(id_) for id_ in raw_ids_values]
if self._has_different_ctrl_tables: if self._has_different_ctrl_tables:
assert_same_address(self.model_ctrl_table, models, data_name) assert_same_address(self.model_ctrl_table, models, data_name)
model = next(iter(models)) model = next(iter(models))
addr, length = get_address(self.model_ctrl_table, model, data_name) addr, length = get_address(self.model_ctrl_table, model, data_name)
int_ids_values = {id_: int(val) for id_, val in raw_ids_values.items()}
if normalize and data_name in self.normalized_data: if normalize and data_name in self.normalized_data:
ids_values = self._unnormalize(ids_values) int_ids_values = self._unnormalize(raw_ids_values)
ids_values = self._encode_sign(data_name, ids_values) int_ids_values = self._encode_sign(data_name, int_ids_values)
err_msg = f"Failed to sync write '{data_name}' with {ids_values=} after {num_retry + 1} tries." err_msg = f"Failed to sync write '{data_name}' with ids_values={int_ids_values} after {num_retry + 1} tries."
self._sync_write(addr, length, ids_values, num_retry=num_retry, raise_on_error=True, err_msg=err_msg) self._sync_write(
addr, length, int_ids_values, num_retry=num_retry, raise_on_error=True, err_msg=err_msg
)
def _sync_write( def _sync_write(
self, self,