Make transport module Mypy Compliant [issue#1731] (#2433)

* latest

* Delete =3.0.0

Signed-off-by: Md. Muhaimin Rahman <sezan92@gmail.com>

* Update src/lerobot/transport/utils.py

Signed-off-by: Md. Muhaimin Rahman <sezan92@gmail.com>

---------

Signed-off-by: Md. Muhaimin Rahman <sezan92@gmail.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
This commit is contained in:
Md. Muhaimin Rahman
2025-12-03 06:12:15 +09:00
committed by GitHub
parent 797cd2725a
commit 0b497fc37d
2 changed files with 32 additions and 29 deletions
+3 -3
View File
@@ -360,9 +360,9 @@ ignore_errors = false
# module = "lerobot.async_inference.*" # module = "lerobot.async_inference.*"
# ignore_errors = false # ignore_errors = false
# [[tool.mypy.overrides]] [[tool.mypy.overrides]]
# module = "lerobot.transport.*" module = "lerobot.transport.*"
# ignore_errors = false ignore_errors = false
# [[tool.mypy.overrides]] # [[tool.mypy.overrides]]
# module = "lerobot.scripts.*" # module = "lerobot.scripts.*"
+29 -26
View File
@@ -19,7 +19,7 @@ import io
import json import json
import logging import logging
import pickle # nosec B403: Safe usage for internal serialization only import pickle # nosec B403: Safe usage for internal serialization only
from multiprocessing import Event from multiprocessing.synchronize import Event as MpEvent
from queue import Queue from queue import Queue
from typing import Any from typing import Any
@@ -28,6 +28,9 @@ import torch
from lerobot.transport import services_pb2 from lerobot.transport import services_pb2
from lerobot.utils.transition import Transition from lerobot.utils.transition import Transition
# FIX for protobuf: Assign the enum to a variable and ignore the type error once
TransferState = services_pb2.TransferState # type: ignore[attr-defined]
CHUNK_SIZE = 2 * 1024 * 1024 # 2 MB CHUNK_SIZE = 2 * 1024 * 1024 # 2 MB
MAX_MESSAGE_SIZE = 4 * 1024 * 1024 # 4 MB MAX_MESSAGE_SIZE = 4 * 1024 * 1024 # 4 MB
@@ -40,8 +43,8 @@ def bytes_buffer_size(buffer: io.BytesIO) -> int:
def send_bytes_in_chunks(buffer: bytes, message_class: Any, log_prefix: str = "", silent: bool = True): def send_bytes_in_chunks(buffer: bytes, message_class: Any, log_prefix: str = "", silent: bool = True):
buffer = io.BytesIO(buffer) bytes_buffer: io.BytesIO = io.BytesIO(buffer)
size_in_bytes = bytes_buffer_size(buffer) size_in_bytes = bytes_buffer_size(bytes_buffer)
sent_bytes = 0 sent_bytes = 0
@@ -50,15 +53,15 @@ def send_bytes_in_chunks(buffer: bytes, message_class: Any, log_prefix: str = ""
logging_method(f"{log_prefix} Buffer size {size_in_bytes / 1024 / 1024} MB with") logging_method(f"{log_prefix} Buffer size {size_in_bytes / 1024 / 1024} MB with")
while sent_bytes < size_in_bytes: while sent_bytes < size_in_bytes:
transfer_state = services_pb2.TransferState.TRANSFER_MIDDLE transfer_state = TransferState.TRANSFER_MIDDLE
if sent_bytes + CHUNK_SIZE >= size_in_bytes: if sent_bytes + CHUNK_SIZE >= size_in_bytes:
transfer_state = services_pb2.TransferState.TRANSFER_END transfer_state = TransferState.TRANSFER_END
elif sent_bytes == 0: elif sent_bytes == 0:
transfer_state = services_pb2.TransferState.TRANSFER_BEGIN transfer_state = TransferState.TRANSFER_BEGIN
size_to_read = min(CHUNK_SIZE, size_in_bytes - sent_bytes) size_to_read = min(CHUNK_SIZE, size_in_bytes - sent_bytes)
chunk = buffer.read(size_to_read) chunk = bytes_buffer.read(size_to_read)
yield message_class(transfer_state=transfer_state, data=chunk) yield message_class(transfer_state=transfer_state, data=chunk)
sent_bytes += size_to_read sent_bytes += size_to_read
@@ -67,7 +70,7 @@ def send_bytes_in_chunks(buffer: bytes, message_class: Any, log_prefix: str = ""
logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB") logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB")
def receive_bytes_in_chunks(iterator, queue: Queue | None, shutdown_event: Event, log_prefix: str = ""): def receive_bytes_in_chunks(iterator, queue: Queue | None, shutdown_event: MpEvent, log_prefix: str = ""):
bytes_buffer = io.BytesIO() bytes_buffer = io.BytesIO()
step = 0 step = 0
@@ -78,17 +81,17 @@ def receive_bytes_in_chunks(iterator, queue: Queue | None, shutdown_event: Event
logging.info(f"{log_prefix} Shutting down receiver") logging.info(f"{log_prefix} Shutting down receiver")
return return
if item.transfer_state == services_pb2.TransferState.TRANSFER_BEGIN: if item.transfer_state == TransferState.TRANSFER_BEGIN:
bytes_buffer.seek(0) bytes_buffer.seek(0)
bytes_buffer.truncate(0) bytes_buffer.truncate(0)
bytes_buffer.write(item.data) bytes_buffer.write(item.data)
logging.debug(f"{log_prefix} Received data at step 0") logging.debug(f"{log_prefix} Received data at step 0")
step = 0 step = 0
elif item.transfer_state == services_pb2.TransferState.TRANSFER_MIDDLE: elif item.transfer_state == TransferState.TRANSFER_MIDDLE:
bytes_buffer.write(item.data) bytes_buffer.write(item.data)
step += 1 step += 1
logging.debug(f"{log_prefix} Received data at step {step}") logging.debug(f"{log_prefix} Received data at step {step}")
elif item.transfer_state == services_pb2.TransferState.TRANSFER_END: elif item.transfer_state == TransferState.TRANSFER_END:
bytes_buffer.write(item.data) bytes_buffer.write(item.data)
logging.debug(f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}") logging.debug(f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}")
@@ -109,17 +112,17 @@ def receive_bytes_in_chunks(iterator, queue: Queue | None, shutdown_event: Event
def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> bytes: def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> bytes:
"""Convert model state dict to flat array for transmission""" """Convert model state dict to flat array for transmission"""
buffer = io.BytesIO() bytes_buffer = io.BytesIO()
torch.save(state_dict, buffer) torch.save(state_dict, bytes_buffer)
return buffer.getvalue() return bytes_buffer.getvalue()
def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]: def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]:
buffer = io.BytesIO(buffer) bytes_buffer = io.BytesIO(buffer)
buffer.seek(0) bytes_buffer.seek(0)
return torch.load(buffer, weights_only=True) return torch.load(bytes_buffer, weights_only=True)
def python_object_to_bytes(python_object: Any) -> bytes: def python_object_to_bytes(python_object: Any) -> bytes:
@@ -127,24 +130,24 @@ def python_object_to_bytes(python_object: Any) -> bytes:
def bytes_to_python_object(buffer: bytes) -> Any: def bytes_to_python_object(buffer: bytes) -> Any:
buffer = io.BytesIO(buffer) bytes_buffer = io.BytesIO(buffer)
buffer.seek(0) bytes_buffer.seek(0)
obj = pickle.load(buffer) # nosec B301: Safe usage of pickle.load obj = pickle.load(bytes_buffer) # nosec B301: Safe usage of pickle.load
# Add validation checks here # Add validation checks here
return obj return obj
def bytes_to_transitions(buffer: bytes) -> list[Transition]: def bytes_to_transitions(buffer: bytes) -> list[Transition]:
buffer = io.BytesIO(buffer) bytes_buffer = io.BytesIO(buffer)
buffer.seek(0) bytes_buffer.seek(0)
transitions = torch.load(buffer, weights_only=True) transitions = torch.load(bytes_buffer, weights_only=True)
return transitions return transitions
def transitions_to_bytes(transitions: list[Transition]) -> bytes: def transitions_to_bytes(transitions: list[Transition]) -> bytes:
buffer = io.BytesIO() bytes_buffer = io.BytesIO()
torch.save(transitions, buffer) torch.save(transitions, bytes_buffer)
return buffer.getvalue() return bytes_buffer.getvalue()
def grpc_channel_options( def grpc_channel_options(