mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 11:09:59 +00:00
Make transport module Mypy Compliant [issue#1731] (#2433)
* latest * Delete =3.0.0 Signed-off-by: Md. Muhaimin Rahman <sezan92@gmail.com> * Update src/lerobot/transport/utils.py Signed-off-by: Md. Muhaimin Rahman <sezan92@gmail.com> --------- Signed-off-by: Md. Muhaimin Rahman <sezan92@gmail.com> Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
This commit is contained in:
committed by
GitHub
parent
797cd2725a
commit
0b497fc37d
+3
-3
@@ -360,9 +360,9 @@ ignore_errors = false
|
|||||||
# module = "lerobot.async_inference.*"
|
# module = "lerobot.async_inference.*"
|
||||||
# ignore_errors = false
|
# ignore_errors = false
|
||||||
|
|
||||||
# [[tool.mypy.overrides]]
|
[[tool.mypy.overrides]]
|
||||||
# module = "lerobot.transport.*"
|
module = "lerobot.transport.*"
|
||||||
# ignore_errors = false
|
ignore_errors = false
|
||||||
|
|
||||||
# [[tool.mypy.overrides]]
|
# [[tool.mypy.overrides]]
|
||||||
# module = "lerobot.scripts.*"
|
# module = "lerobot.scripts.*"
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ import io
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import pickle # nosec B403: Safe usage for internal serialization only
|
import pickle # nosec B403: Safe usage for internal serialization only
|
||||||
from multiprocessing import Event
|
from multiprocessing.synchronize import Event as MpEvent
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -28,6 +28,9 @@ import torch
|
|||||||
from lerobot.transport import services_pb2
|
from lerobot.transport import services_pb2
|
||||||
from lerobot.utils.transition import Transition
|
from lerobot.utils.transition import Transition
|
||||||
|
|
||||||
|
# FIX for protobuf: Assign the enum to a variable and ignore the type error once
|
||||||
|
TransferState = services_pb2.TransferState # type: ignore[attr-defined]
|
||||||
|
|
||||||
CHUNK_SIZE = 2 * 1024 * 1024 # 2 MB
|
CHUNK_SIZE = 2 * 1024 * 1024 # 2 MB
|
||||||
MAX_MESSAGE_SIZE = 4 * 1024 * 1024 # 4 MB
|
MAX_MESSAGE_SIZE = 4 * 1024 * 1024 # 4 MB
|
||||||
|
|
||||||
@@ -40,8 +43,8 @@ def bytes_buffer_size(buffer: io.BytesIO) -> int:
|
|||||||
|
|
||||||
|
|
||||||
def send_bytes_in_chunks(buffer: bytes, message_class: Any, log_prefix: str = "", silent: bool = True):
|
def send_bytes_in_chunks(buffer: bytes, message_class: Any, log_prefix: str = "", silent: bool = True):
|
||||||
buffer = io.BytesIO(buffer)
|
bytes_buffer: io.BytesIO = io.BytesIO(buffer)
|
||||||
size_in_bytes = bytes_buffer_size(buffer)
|
size_in_bytes = bytes_buffer_size(bytes_buffer)
|
||||||
|
|
||||||
sent_bytes = 0
|
sent_bytes = 0
|
||||||
|
|
||||||
@@ -50,15 +53,15 @@ def send_bytes_in_chunks(buffer: bytes, message_class: Any, log_prefix: str = ""
|
|||||||
logging_method(f"{log_prefix} Buffer size {size_in_bytes / 1024 / 1024} MB with")
|
logging_method(f"{log_prefix} Buffer size {size_in_bytes / 1024 / 1024} MB with")
|
||||||
|
|
||||||
while sent_bytes < size_in_bytes:
|
while sent_bytes < size_in_bytes:
|
||||||
transfer_state = services_pb2.TransferState.TRANSFER_MIDDLE
|
transfer_state = TransferState.TRANSFER_MIDDLE
|
||||||
|
|
||||||
if sent_bytes + CHUNK_SIZE >= size_in_bytes:
|
if sent_bytes + CHUNK_SIZE >= size_in_bytes:
|
||||||
transfer_state = services_pb2.TransferState.TRANSFER_END
|
transfer_state = TransferState.TRANSFER_END
|
||||||
elif sent_bytes == 0:
|
elif sent_bytes == 0:
|
||||||
transfer_state = services_pb2.TransferState.TRANSFER_BEGIN
|
transfer_state = TransferState.TRANSFER_BEGIN
|
||||||
|
|
||||||
size_to_read = min(CHUNK_SIZE, size_in_bytes - sent_bytes)
|
size_to_read = min(CHUNK_SIZE, size_in_bytes - sent_bytes)
|
||||||
chunk = buffer.read(size_to_read)
|
chunk = bytes_buffer.read(size_to_read)
|
||||||
|
|
||||||
yield message_class(transfer_state=transfer_state, data=chunk)
|
yield message_class(transfer_state=transfer_state, data=chunk)
|
||||||
sent_bytes += size_to_read
|
sent_bytes += size_to_read
|
||||||
@@ -67,7 +70,7 @@ def send_bytes_in_chunks(buffer: bytes, message_class: Any, log_prefix: str = ""
|
|||||||
logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB")
|
logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB")
|
||||||
|
|
||||||
|
|
||||||
def receive_bytes_in_chunks(iterator, queue: Queue | None, shutdown_event: Event, log_prefix: str = ""):
|
def receive_bytes_in_chunks(iterator, queue: Queue | None, shutdown_event: MpEvent, log_prefix: str = ""):
|
||||||
bytes_buffer = io.BytesIO()
|
bytes_buffer = io.BytesIO()
|
||||||
step = 0
|
step = 0
|
||||||
|
|
||||||
@@ -78,17 +81,17 @@ def receive_bytes_in_chunks(iterator, queue: Queue | None, shutdown_event: Event
|
|||||||
logging.info(f"{log_prefix} Shutting down receiver")
|
logging.info(f"{log_prefix} Shutting down receiver")
|
||||||
return
|
return
|
||||||
|
|
||||||
if item.transfer_state == services_pb2.TransferState.TRANSFER_BEGIN:
|
if item.transfer_state == TransferState.TRANSFER_BEGIN:
|
||||||
bytes_buffer.seek(0)
|
bytes_buffer.seek(0)
|
||||||
bytes_buffer.truncate(0)
|
bytes_buffer.truncate(0)
|
||||||
bytes_buffer.write(item.data)
|
bytes_buffer.write(item.data)
|
||||||
logging.debug(f"{log_prefix} Received data at step 0")
|
logging.debug(f"{log_prefix} Received data at step 0")
|
||||||
step = 0
|
step = 0
|
||||||
elif item.transfer_state == services_pb2.TransferState.TRANSFER_MIDDLE:
|
elif item.transfer_state == TransferState.TRANSFER_MIDDLE:
|
||||||
bytes_buffer.write(item.data)
|
bytes_buffer.write(item.data)
|
||||||
step += 1
|
step += 1
|
||||||
logging.debug(f"{log_prefix} Received data at step {step}")
|
logging.debug(f"{log_prefix} Received data at step {step}")
|
||||||
elif item.transfer_state == services_pb2.TransferState.TRANSFER_END:
|
elif item.transfer_state == TransferState.TRANSFER_END:
|
||||||
bytes_buffer.write(item.data)
|
bytes_buffer.write(item.data)
|
||||||
logging.debug(f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}")
|
logging.debug(f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}")
|
||||||
|
|
||||||
@@ -109,17 +112,17 @@ def receive_bytes_in_chunks(iterator, queue: Queue | None, shutdown_event: Event
|
|||||||
|
|
||||||
def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> bytes:
|
def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> bytes:
|
||||||
"""Convert model state dict to flat array for transmission"""
|
"""Convert model state dict to flat array for transmission"""
|
||||||
buffer = io.BytesIO()
|
bytes_buffer = io.BytesIO()
|
||||||
|
|
||||||
torch.save(state_dict, buffer)
|
torch.save(state_dict, bytes_buffer)
|
||||||
|
|
||||||
return buffer.getvalue()
|
return bytes_buffer.getvalue()
|
||||||
|
|
||||||
|
|
||||||
def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]:
|
def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]:
|
||||||
buffer = io.BytesIO(buffer)
|
bytes_buffer = io.BytesIO(buffer)
|
||||||
buffer.seek(0)
|
bytes_buffer.seek(0)
|
||||||
return torch.load(buffer, weights_only=True)
|
return torch.load(bytes_buffer, weights_only=True)
|
||||||
|
|
||||||
|
|
||||||
def python_object_to_bytes(python_object: Any) -> bytes:
|
def python_object_to_bytes(python_object: Any) -> bytes:
|
||||||
@@ -127,24 +130,24 @@ def python_object_to_bytes(python_object: Any) -> bytes:
|
|||||||
|
|
||||||
|
|
||||||
def bytes_to_python_object(buffer: bytes) -> Any:
|
def bytes_to_python_object(buffer: bytes) -> Any:
|
||||||
buffer = io.BytesIO(buffer)
|
bytes_buffer = io.BytesIO(buffer)
|
||||||
buffer.seek(0)
|
bytes_buffer.seek(0)
|
||||||
obj = pickle.load(buffer) # nosec B301: Safe usage of pickle.load
|
obj = pickle.load(bytes_buffer) # nosec B301: Safe usage of pickle.load
|
||||||
# Add validation checks here
|
# Add validation checks here
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
||||||
def bytes_to_transitions(buffer: bytes) -> list[Transition]:
|
def bytes_to_transitions(buffer: bytes) -> list[Transition]:
|
||||||
buffer = io.BytesIO(buffer)
|
bytes_buffer = io.BytesIO(buffer)
|
||||||
buffer.seek(0)
|
bytes_buffer.seek(0)
|
||||||
transitions = torch.load(buffer, weights_only=True)
|
transitions = torch.load(bytes_buffer, weights_only=True)
|
||||||
return transitions
|
return transitions
|
||||||
|
|
||||||
|
|
||||||
def transitions_to_bytes(transitions: list[Transition]) -> bytes:
|
def transitions_to_bytes(transitions: list[Transition]) -> bytes:
|
||||||
buffer = io.BytesIO()
|
bytes_buffer = io.BytesIO()
|
||||||
torch.save(transitions, buffer)
|
torch.save(transitions, bytes_buffer)
|
||||||
return buffer.getvalue()
|
return bytes_buffer.getvalue()
|
||||||
|
|
||||||
|
|
||||||
def grpc_channel_options(
|
def grpc_channel_options(
|
||||||
|
|||||||
Reference in New Issue
Block a user