From f4b834844e34f7dcf8245e17818e6bdd9ebdcaff Mon Sep 17 00:00:00 2001 From: Virgileboat <116651491+Virgileboat@users.noreply.github.com> Date: Thu, 21 May 2026 11:44:04 +0200 Subject: [PATCH] Feat/clean can bus (#3526) * change timeout for handshake * enforce last state read when querying * change import order * fix(motors): flush stale robstride RX and harden feedback drain * robstride: remove redundant timeout and max_messages casts * bugfix + %-style * update exception catch --- src/lerobot/motors/robstride/robstride.py | 118 ++++++++++++++++++---- src/lerobot/motors/robstride/tables.py | 3 +- 2 files changed, 102 insertions(+), 19 deletions(-) diff --git a/src/lerobot/motors/robstride/robstride.py b/src/lerobot/motors/robstride/robstride.py index ecde01e9a..359fc9385 100644 --- a/src/lerobot/motors/robstride/robstride.py +++ b/src/lerobot/motors/robstride/robstride.py @@ -43,6 +43,7 @@ from .tables import ( CAN_CMD_SET_ZERO, DEFAULT_BAUDRATE, DEFAULT_TIMEOUT_MS, + HANDSHAKE_TIMEOUT_S, MODEL_RESOLUTION, MOTOR_LIMIT_PARAMS, NORMALIZED_DATA, @@ -215,14 +216,16 @@ class RobstrideMotorsBus(MotorsBusBase): self._is_connected = False raise ConnectionError(f"Failed to connect to CAN bus: {e}") from e - def _query_status_via_clear_fault(self, motor: NameOrID) -> tuple[bool, can.Message | None]: + def _query_status_via_clear_fault( + self, motor: NameOrID, timeout: float = RUNNING_TIMEOUT + ) -> tuple[bool, can.Message | None]: motor_name = self._get_motor_name(motor) motor_id = self._get_motor_id(motor_name) recv_id = self._get_motor_recv_id(motor_name) data = [0xFF] * 7 + [CAN_CMD_CLEAR_FAULT] msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False) self._bus().send(msg) - return self._recv_status_via_clear_fault(expected_recv_id=recv_id) + return self._recv_status_via_clear_fault(expected_recv_id=recv_id, timeout=timeout) def _recv_status_via_clear_fault( self, expected_recv_id: int | None = None, timeout: float = RUNNING_TIMEOUT @@ -280,7 +283,7 @@
class RobstrideMotorsBus(MotorsBusBase): faulted_motors = [] for motor_name in self.motors: - has_fault, msg = self._query_status_via_clear_fault(motor_name) + has_fault, msg = self._query_status_via_clear_fault(motor_name, timeout=HANDSHAKE_TIMEOUT_S) if msg is None: missing_motors.append(motor_name) elif has_fault: @@ -505,6 +508,87 @@ class RobstrideMotorsBus(MotorsBusBase): return responses + def _recv_all_messages_until_quiet( + self, + *, + timeout: float = RUNNING_TIMEOUT, + max_messages: int = 4096, + ) -> list[can.Message]: + """ + Receive frames until the bus goes quiet. + + Args: + timeout: Poll timeout used for each recv() call. Collection stops + when one recv() times out (quiet gap). + max_messages: Safety cap to prevent unbounded loops. + """ + out: list[can.Message] = [] + max_messages = max(1, max_messages) + timeout = max(0.0, timeout) + + try: + while len(out) < max_messages: + msg = self._bus().recv(timeout=timeout) + if msg is None: + break + out.append(msg) + except (can.CanError, OSError) as e: + logger.debug(f"Error draining CAN RX queue on {self.port}: {e}") + + return out + + def _process_feedback_messages(self, messages: list[can.Message]) -> set[int]: + """ + Decode all received feedback frames and update cached motor states. + + Returns: + Set of payload recv_ids that were successfully mapped to motors. 
+ """ + processed_recv_ids: set[int] = set() + for msg in messages: + if len(msg.data) < 1: + logger.debug( + f"Dropping short CAN frame on {self.port} " + f"(arb=0x{int(msg.arbitration_id):02X}, data={bytes(msg.data).hex()})" + ) + continue + + recv_id = int(msg.data[0]) + motor_name = self._recv_id_to_motor.get(recv_id) + if motor_name is None: + logger.debug( + f"Unmapped CAN frame on {self.port} " + f"(arb=0x{int(msg.arbitration_id):02X}, recv_id=0x{recv_id:02X}, data={bytes(msg.data).hex()})" + ) + continue + + self._process_response(motor_name, msg) + processed_recv_ids.add(recv_id) + + return processed_recv_ids + + def flush_rx_queue(self, poll_timeout_s: float = 0.0005, max_messages: int = 4096) -> int: + """ + Drain pending RX frames from the CAN interface. + + This is used by higher-level controllers to drop stale feedback before issuing + a fresh read cycle, so subsequent state reads are based on most recent replies. + It should also be called once when a controller instance is created/connected, + to clear residual frames left on the interface from previous sessions. + """ + drained = 0 + poll_timeout_s = max(0.0, poll_timeout_s) + max_messages = max(1, max_messages) + try: + while drained < max_messages: + msg = self._bus().recv(timeout=poll_timeout_s) + if msg is None: + break + drained += 1 + except (can.CanError, OSError) as e: + logger.debug(f"Failed to flush CAN RX queue on {self.port}: {e}") + return drained + def _speed_control( self, motor: NameOrID, @@ -644,11 +728,14 @@ class RobstrideMotorsBus(MotorsBusBase): msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False) self._bus().send(msg) recv_id_to_motor[self._get_motor_recv_id(motor)] = motor_name + # Read every feedback frame until RX goes quiet, then decode all of them. + # This avoids dropping useful frames when responses from different motors interleave. 
+ messages = self._recv_all_messages_until_quiet() + processed_recv_ids = self._process_feedback_messages(messages) - responses = self._recv_all_responses(list(recv_id_to_motor.keys()), timeout=RUNNING_TIMEOUT) for recv_id, motor_name in recv_id_to_motor.items(): - if msg := responses.get(recv_id): - self._process_response(motor_name, msg) + if recv_id not in processed_recv_ids: + logger.warning(f"Packet drop: {motor_name} (ID: 0x{recv_id:02X}). Using last known state.") def _float_to_uint(self, x: float, x_min: float, x_max: float, bits: int) -> int: """Convert float to unsigned integer for CAN transmission.""" @@ -711,7 +798,10 @@ class RobstrideMotorsBus(MotorsBusBase): try: self._decode_motor_state(msg.data) except Exception as e: - logger.warning(f"Failed to decode response from {motor}: {e}") + logger.warning( + f"Failed to decode response from {motor} " + f"(arb=0x{int(msg.arbitration_id):02X}, data={bytes(msg.data).hex()}): {e}" + ) def _get_cached_value(self, motor: str, data_name: str) -> Value: """Retrieve a specific value from the state cache.""" @@ -848,20 +938,12 @@ class RobstrideMotorsBus(MotorsBusBase): self._bus().send(msg) updated_motors.append(motor) - expected_recv_ids = [self._get_motor_recv_id(motor) for motor in updated_motors] - responses = self._recv_all_responses(expected_recv_ids, timeout=RUNNING_TIMEOUT) - - for response in responses.values(): - payload_motor_name = self._recv_id_to_motor.get(response.data[0]) - if payload_motor_name is not None: - self._process_response(payload_motor_name, response) - else: - # Fallback: still attempt to decode based on payload byte0 mapping. - self._decode_motor_state(response.data) + messages = self._recv_all_messages_until_quiet() + processed_recv_ids = self._process_feedback_messages(messages) for motor in updated_motors: recv_id = self._get_motor_recv_id(motor) - if recv_id not in responses: + if recv_id not in processed_recv_ids: logger.warning(f"Packet drop: {motor} (ID: 0x{recv_id:02X}). 
Using last known state.") def read_calibration(self) -> dict[str, MotorCalibration]: diff --git a/src/lerobot/motors/robstride/tables.py b/src/lerobot/motors/robstride/tables.py index 2fc1a97b0..06b90df3a 100644 --- a/src/lerobot/motors/robstride/tables.py +++ b/src/lerobot/motors/robstride/tables.py @@ -114,7 +114,8 @@ CAN_CMD_SAVE_PARAM = 0xAA CAN_PARAM_ID = 0x7FF -RUNNING_TIMEOUT = 0.001 +RUNNING_TIMEOUT = 0.003 +HANDSHAKE_TIMEOUT_S = 0.05 PARAM_TIMEOUT = 0.01 STATE_CACHE_TTL_S = 0.02