diff --git a/pyproject.toml b/pyproject.toml index d89f433e8..638b2326f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -360,9 +360,9 @@ ignore_errors = false # module = "lerobot.async_inference.*" # ignore_errors = false -# [[tool.mypy.overrides]] -# module = "lerobot.transport.*" -# ignore_errors = false +[[tool.mypy.overrides]] +module = "lerobot.transport.*" +ignore_errors = false # [[tool.mypy.overrides]] # module = "lerobot.scripts.*" diff --git a/src/lerobot/transport/utils.py b/src/lerobot/transport/utils.py index 5c9f702fc..8da338044 100644 --- a/src/lerobot/transport/utils.py +++ b/src/lerobot/transport/utils.py @@ -19,7 +19,7 @@ import io import json import logging import pickle # nosec B403: Safe usage for internal serialization only -from multiprocessing import Event +from multiprocessing.synchronize import Event as MpEvent from queue import Queue from typing import Any @@ -28,6 +28,9 @@ import torch from lerobot.transport import services_pb2 from lerobot.utils.transition import Transition +# FIX for protobuf: Assign the enum to a variable and ignore the type error once +TransferState = services_pb2.TransferState # type: ignore[attr-defined] + CHUNK_SIZE = 2 * 1024 * 1024 # 2 MB MAX_MESSAGE_SIZE = 4 * 1024 * 1024 # 4 MB @@ -40,8 +43,8 @@ def bytes_buffer_size(buffer: io.BytesIO) -> int: def send_bytes_in_chunks(buffer: bytes, message_class: Any, log_prefix: str = "", silent: bool = True): - buffer = io.BytesIO(buffer) - size_in_bytes = bytes_buffer_size(buffer) + bytes_buffer: io.BytesIO = io.BytesIO(buffer) + size_in_bytes = bytes_buffer_size(bytes_buffer) sent_bytes = 0 @@ -50,15 +53,15 @@ def send_bytes_in_chunks(buffer: bytes, message_class: Any, log_prefix: str = "" logging_method(f"{log_prefix} Buffer size {size_in_bytes / 1024 / 1024} MB with") while sent_bytes < size_in_bytes: - transfer_state = services_pb2.TransferState.TRANSFER_MIDDLE + transfer_state = TransferState.TRANSFER_MIDDLE if sent_bytes + CHUNK_SIZE >= size_in_bytes: - transfer_state = services_pb2.TransferState.TRANSFER_END + transfer_state = TransferState.TRANSFER_END elif sent_bytes == 0: - transfer_state = services_pb2.TransferState.TRANSFER_BEGIN + transfer_state = TransferState.TRANSFER_BEGIN size_to_read = min(CHUNK_SIZE, size_in_bytes - sent_bytes) - chunk = buffer.read(size_to_read) + chunk = bytes_buffer.read(size_to_read) yield message_class(transfer_state=transfer_state, data=chunk) sent_bytes += size_to_read @@ -67,7 +70,7 @@ def send_bytes_in_chunks(buffer: bytes, message_class: Any, log_prefix: str = "" logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB") -def receive_bytes_in_chunks(iterator, queue: Queue | None, shutdown_event: Event, log_prefix: str = ""): +def receive_bytes_in_chunks(iterator, queue: Queue | None, shutdown_event: MpEvent, log_prefix: str = ""): bytes_buffer = io.BytesIO() step = 0 @@ -78,17 +81,17 @@ def receive_bytes_in_chunks(iterator, queue: Queue | None, shutdown_event: Event logging.info(f"{log_prefix} Shutting down receiver") return - if item.transfer_state == services_pb2.TransferState.TRANSFER_BEGIN: + if item.transfer_state == TransferState.TRANSFER_BEGIN: bytes_buffer.seek(0) bytes_buffer.truncate(0) bytes_buffer.write(item.data) logging.debug(f"{log_prefix} Received data at step 0") step = 0 - elif item.transfer_state == services_pb2.TransferState.TRANSFER_MIDDLE: + elif item.transfer_state == TransferState.TRANSFER_MIDDLE: bytes_buffer.write(item.data) step += 1 logging.debug(f"{log_prefix} Received data at step {step}") - elif item.transfer_state == services_pb2.TransferState.TRANSFER_END: + elif item.transfer_state == TransferState.TRANSFER_END: bytes_buffer.write(item.data) logging.debug(f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}") @@ -109,17 +112,17 @@ def receive_bytes_in_chunks(iterator, queue: Queue | None, shutdown_event: Event def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> bytes: """Convert model state dict to flat array for transmission""" - buffer = io.BytesIO() + bytes_buffer = io.BytesIO() - torch.save(state_dict, buffer) + torch.save(state_dict, bytes_buffer) - return buffer.getvalue() + return bytes_buffer.getvalue() def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]: - buffer = io.BytesIO(buffer) - buffer.seek(0) - return torch.load(buffer, weights_only=True) + bytes_buffer = io.BytesIO(buffer) + bytes_buffer.seek(0) + return torch.load(bytes_buffer, weights_only=True) def python_object_to_bytes(python_object: Any) -> bytes: @@ -127,24 +130,24 @@ def python_object_to_bytes(python_object: Any) -> bytes: def bytes_to_python_object(buffer: bytes) -> Any: - buffer = io.BytesIO(buffer) - buffer.seek(0) - obj = pickle.load(buffer) # nosec B301: Safe usage of pickle.load + bytes_buffer = io.BytesIO(buffer) + bytes_buffer.seek(0) + obj = pickle.load(bytes_buffer) # nosec B301: Safe usage of pickle.load # Add validation checks here return obj def bytes_to_transitions(buffer: bytes) -> list[Transition]: - buffer = io.BytesIO(buffer) - buffer.seek(0) - transitions = torch.load(buffer, weights_only=True) + bytes_buffer = io.BytesIO(buffer) + bytes_buffer.seek(0) + transitions = torch.load(bytes_buffer, weights_only=True) return transitions def transitions_to_bytes(transitions: list[Transition]) -> bytes: - buffer = io.BytesIO() - torch.save(transitions, buffer) - return buffer.getvalue() + bytes_buffer = io.BytesIO() + torch.save(transitions, bytes_buffer) + return bytes_buffer.getvalue() def grpc_channel_options(