mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-17 01:30:14 +00:00
Make transport module Mypy Compliant [issue#1731] (#2433)
* latest * Delete =3.0.0 Signed-off-by: Md. Muhaimin Rahman <sezan92@gmail.com> * Update src/lerobot/transport/utils.py Signed-off-by: Md. Muhaimin Rahman <sezan92@gmail.com> --------- Signed-off-by: Md. Muhaimin Rahman <sezan92@gmail.com> Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
This commit is contained in:
committed by
GitHub
parent
797cd2725a
commit
0b497fc37d
@@ -19,7 +19,7 @@ import io
|
||||
import json
|
||||
import logging
|
||||
import pickle # nosec B403: Safe usage for internal serialization only
|
||||
from multiprocessing import Event
|
||||
from multiprocessing.synchronize import Event as MpEvent
|
||||
from queue import Queue
|
||||
from typing import Any
|
||||
|
||||
@@ -28,6 +28,9 @@ import torch
|
||||
from lerobot.transport import services_pb2
|
||||
from lerobot.utils.transition import Transition
|
||||
|
||||
# FIX for protobuf: Assign the enum to a variable and ignore the type error once
|
||||
TransferState = services_pb2.TransferState # type: ignore[attr-defined]
|
||||
|
||||
CHUNK_SIZE = 2 * 1024 * 1024 # 2 MB
|
||||
MAX_MESSAGE_SIZE = 4 * 1024 * 1024 # 4 MB
|
||||
|
||||
@@ -40,8 +43,8 @@ def bytes_buffer_size(buffer: io.BytesIO) -> int:
|
||||
|
||||
|
||||
def send_bytes_in_chunks(buffer: bytes, message_class: Any, log_prefix: str = "", silent: bool = True):
|
||||
buffer = io.BytesIO(buffer)
|
||||
size_in_bytes = bytes_buffer_size(buffer)
|
||||
bytes_buffer: io.BytesIO = io.BytesIO(buffer)
|
||||
size_in_bytes = bytes_buffer_size(bytes_buffer)
|
||||
|
||||
sent_bytes = 0
|
||||
|
||||
@@ -50,15 +53,15 @@ def send_bytes_in_chunks(buffer: bytes, message_class: Any, log_prefix: str = ""
|
||||
logging_method(f"{log_prefix} Buffer size {size_in_bytes / 1024 / 1024} MB with")
|
||||
|
||||
while sent_bytes < size_in_bytes:
|
||||
transfer_state = services_pb2.TransferState.TRANSFER_MIDDLE
|
||||
transfer_state = TransferState.TRANSFER_MIDDLE
|
||||
|
||||
if sent_bytes + CHUNK_SIZE >= size_in_bytes:
|
||||
transfer_state = services_pb2.TransferState.TRANSFER_END
|
||||
transfer_state = TransferState.TRANSFER_END
|
||||
elif sent_bytes == 0:
|
||||
transfer_state = services_pb2.TransferState.TRANSFER_BEGIN
|
||||
transfer_state = TransferState.TRANSFER_BEGIN
|
||||
|
||||
size_to_read = min(CHUNK_SIZE, size_in_bytes - sent_bytes)
|
||||
chunk = buffer.read(size_to_read)
|
||||
chunk = bytes_buffer.read(size_to_read)
|
||||
|
||||
yield message_class(transfer_state=transfer_state, data=chunk)
|
||||
sent_bytes += size_to_read
|
||||
@@ -67,7 +70,7 @@ def send_bytes_in_chunks(buffer: bytes, message_class: Any, log_prefix: str = ""
|
||||
logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB")
|
||||
|
||||
|
||||
def receive_bytes_in_chunks(iterator, queue: Queue | None, shutdown_event: Event, log_prefix: str = ""):
|
||||
def receive_bytes_in_chunks(iterator, queue: Queue | None, shutdown_event: MpEvent, log_prefix: str = ""):
|
||||
bytes_buffer = io.BytesIO()
|
||||
step = 0
|
||||
|
||||
@@ -78,17 +81,17 @@ def receive_bytes_in_chunks(iterator, queue: Queue | None, shutdown_event: Event
|
||||
logging.info(f"{log_prefix} Shutting down receiver")
|
||||
return
|
||||
|
||||
if item.transfer_state == services_pb2.TransferState.TRANSFER_BEGIN:
|
||||
if item.transfer_state == TransferState.TRANSFER_BEGIN:
|
||||
bytes_buffer.seek(0)
|
||||
bytes_buffer.truncate(0)
|
||||
bytes_buffer.write(item.data)
|
||||
logging.debug(f"{log_prefix} Received data at step 0")
|
||||
step = 0
|
||||
elif item.transfer_state == services_pb2.TransferState.TRANSFER_MIDDLE:
|
||||
elif item.transfer_state == TransferState.TRANSFER_MIDDLE:
|
||||
bytes_buffer.write(item.data)
|
||||
step += 1
|
||||
logging.debug(f"{log_prefix} Received data at step {step}")
|
||||
elif item.transfer_state == services_pb2.TransferState.TRANSFER_END:
|
||||
elif item.transfer_state == TransferState.TRANSFER_END:
|
||||
bytes_buffer.write(item.data)
|
||||
logging.debug(f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}")
|
||||
|
||||
@@ -109,17 +112,17 @@ def receive_bytes_in_chunks(iterator, queue: Queue | None, shutdown_event: Event
|
||||
|
||||
def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> bytes:
|
||||
"""Convert model state dict to flat array for transmission"""
|
||||
buffer = io.BytesIO()
|
||||
bytes_buffer = io.BytesIO()
|
||||
|
||||
torch.save(state_dict, buffer)
|
||||
torch.save(state_dict, bytes_buffer)
|
||||
|
||||
return buffer.getvalue()
|
||||
return bytes_buffer.getvalue()
|
||||
|
||||
|
||||
def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]:
|
||||
buffer = io.BytesIO(buffer)
|
||||
buffer.seek(0)
|
||||
return torch.load(buffer, weights_only=True)
|
||||
bytes_buffer = io.BytesIO(buffer)
|
||||
bytes_buffer.seek(0)
|
||||
return torch.load(bytes_buffer, weights_only=True)
|
||||
|
||||
|
||||
def python_object_to_bytes(python_object: Any) -> bytes:
|
||||
@@ -127,24 +130,24 @@ def python_object_to_bytes(python_object: Any) -> bytes:
|
||||
|
||||
|
||||
def bytes_to_python_object(buffer: bytes) -> Any:
|
||||
buffer = io.BytesIO(buffer)
|
||||
buffer.seek(0)
|
||||
obj = pickle.load(buffer) # nosec B301: Safe usage of pickle.load
|
||||
bytes_buffer = io.BytesIO(buffer)
|
||||
bytes_buffer.seek(0)
|
||||
obj = pickle.load(bytes_buffer) # nosec B301: Safe usage of pickle.load
|
||||
# Add validation checks here
|
||||
return obj
|
||||
|
||||
|
||||
def bytes_to_transitions(buffer: bytes) -> list[Transition]:
|
||||
buffer = io.BytesIO(buffer)
|
||||
buffer.seek(0)
|
||||
transitions = torch.load(buffer, weights_only=True)
|
||||
bytes_buffer = io.BytesIO(buffer)
|
||||
bytes_buffer.seek(0)
|
||||
transitions = torch.load(bytes_buffer, weights_only=True)
|
||||
return transitions
|
||||
|
||||
|
||||
def transitions_to_bytes(transitions: list[Transition]) -> bytes:
|
||||
buffer = io.BytesIO()
|
||||
torch.save(transitions, buffer)
|
||||
return buffer.getvalue()
|
||||
bytes_buffer = io.BytesIO()
|
||||
torch.save(transitions, bytes_buffer)
|
||||
return bytes_buffer.getvalue()
|
||||
|
||||
|
||||
def grpc_channel_options(
|
||||
|
||||
Reference in New Issue
Block a user