From 46482e23b7d9ca587288514b79919aa64896f178 Mon Sep 17 00:00:00 2001 From: Virgile Date: Thu, 7 May 2026 10:34:09 +0200 Subject: [PATCH] enforce last state read when querry --- src/lerobot/motors/robstride/robstride.py | 108 +++++++++++++++++++++- 1 file changed, 104 insertions(+), 4 deletions(-) diff --git a/src/lerobot/motors/robstride/robstride.py b/src/lerobot/motors/robstride/robstride.py index f825507b1..af219e335 100644 --- a/src/lerobot/motors/robstride/robstride.py +++ b/src/lerobot/motors/robstride/robstride.py @@ -509,6 +509,97 @@ class RobstrideMotorsBus(MotorsBusBase): return responses + def _recv_all_messages_until_quiet( + self, + *, + timeout: float, + max_messages: int = 4096, + ) -> list[can.Message]: + """ + Receive frames until the bus goes quiet. + + Args: + wait_for_first_s: Timeout waiting for the first frame. Use 0.0 for + non-blocking drain of already queued frames. + poll_timeout_s: Poll timeout used while draining remaining frames. + max_messages: Safety cap to prevent unbounded loops. + """ + out: list[can.Message] = [] + cap = max(1, int(max_messages)) + first_wait = max(0.0, float(timeout)) + poll_wait = max(0.0, float(timeout)) + + try: + first = self._bus().recv(timeout=first_wait) + if first is None: + return out + out.append(first) + + while len(out) < cap: + msg = self._bus().recv(timeout=poll_wait) + if msg is None: + break + out.append(msg) + except Exception as e: + logger.debug("Error draining CAN RX queue on %s: %s", self.port, e) + + return out + + def _process_feedback_messages(self, messages: list[can.Message]) -> set[int]: + """ + Decode all received feedback frames and update cached motor states. + + Returns: + Set of payload recv_ids that were successfully mapped to motors. + """ + processed_recv_ids: set[int] = set() + for msg in messages: + if len(msg.data) < 1: + logger.debug( + "Dropping short CAN frame on %s (arb=0x%02X, data=%s)", + self.port, + int(msg.arbitration_id), + bytes(msg.data).hex(), + ) + continue + + recv_id = int(msg.data[0]) + motor_name = self._recv_id_to_motor.get(recv_id) + if motor_name is None: + logger.debug( + "Unmapped CAN frame on %s (arb=0x%02X, recv_id=0x%02X, data=%s)", + self.port, + int(msg.arbitration_id), + recv_id, + bytes(msg.data).hex(), + ) + continue + + self._process_response(motor_name, msg) + processed_recv_ids.add(recv_id) + + return processed_recv_ids + + def flush_rx_queue(self, poll_timeout_s: float = 0.0005, max_messages: int = 4096) -> int: + """ + Drain pending RX frames from the CAN interface. + + This is used by higher-level controllers to drop stale feedback before issuing + a fresh read cycle, so subsequent state reads are based on most recent replies. + """ + drained = 0 + timeout_s = max(0.0, float(poll_timeout_s)) + cap = max(1, int(max_messages)) + try: + while drained < cap: + msg = self._bus().recv(timeout=timeout_s) + if msg is None: + break + drained += 1 + except Exception as e: + logger.debug("Failed to flush CAN RX queue on %s: %s", self.port, e) + return drained + def _speed_control( self, motor: NameOrID, @@ -648,11 +739,14 @@ class RobstrideMotorsBus(MotorsBusBase): msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False) self._bus().send(msg) recv_id_to_motor[self._get_motor_recv_id(motor)] = motor_name + # Read every feedback frame until RX goes quiet, then decode all of them. + # This avoids dropping useful frames when responses from different motors interleave. + messages = self._recv_all_messages_until_quiet(timeout=RUNNING_TIMEOUT) + processed_recv_ids = self._process_feedback_messages(messages) - responses = self._recv_all_responses(list(recv_id_to_motor.keys()), timeout=RUNNING_TIMEOUT) for recv_id, motor_name in recv_id_to_motor.items(): - if msg := responses.get(recv_id): - self._process_response(motor_name, msg) + if recv_id not in processed_recv_ids: + logger.warning(f"Packet drop: {motor_name} (ID: 0x{recv_id:02X}). Using last known state.") def _float_to_uint(self, x: float, x_min: float, x_max: float, bits: int) -> int: """Convert float to unsigned integer for CAN transmission.""" @@ -715,7 +809,13 @@ class RobstrideMotorsBus(MotorsBusBase): try: self._decode_motor_state(msg.data) except Exception as e: - logger.warning(f"Failed to decode response from {motor}: {e}") + logger.warning( + "Failed to decode response from %s (arb=0x%02X, data=%s): %s", + motor, + int(msg.arbitration_id), + bytes(msg.data).hex(), + e, + ) def _get_cached_value(self, motor: str, data_name: str) -> Value: """Retrieve a specific value from the state cache."""