refactor(rl): hoist grpcio guard to module top in actor/learner

This commit is contained in:
Khalil Meftah
2026-05-09 20:49:53 +02:00
parent a5222d3f1d
commit e7886c2285
3 changed files with 97 additions and 123 deletions
+47 -58
View File
@@ -52,74 +52,63 @@ import time
from collections.abc import Generator
from functools import lru_cache
from queue import Empty
from typing import TYPE_CHECKING, Any
from typing import Any
import torch
from torch import nn
from torch.multiprocessing import Queue
from lerobot.utils.import_utils import require_package
from lerobot.cameras import opencv # noqa: F401
from lerobot.configs import parser
from lerobot.policies import make_policy, make_pre_post_processors
from lerobot.processor import TransitionKey
from lerobot.robots import so_follower # noqa: F401
from lerobot.teleoperators import gamepad, so_leader # noqa: F401
from lerobot.teleoperators.utils import TeleopEvents
from lerobot.utils.device_utils import get_safe_torch_device
from lerobot.utils.import_utils import _grpc_available, require_package
from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.random_utils import set_seed
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.transition import (
# Fail fast with a friendly error if the optional ``hilserl`` extra is missing.
require_package("grpcio", extra="hilserl", import_name="grpc")
import grpc # noqa: E402
import torch # noqa: E402
from torch import nn # noqa: E402
from torch.multiprocessing import Queue # noqa: E402
from lerobot.cameras import opencv # noqa: F401, E402
from lerobot.configs import parser # noqa: E402
from lerobot.policies import make_policy, make_pre_post_processors # noqa: E402
from lerobot.processor import TransitionKey # noqa: E402
from lerobot.robots import so_follower # noqa: F401, E402
from lerobot.teleoperators import gamepad, so_leader # noqa: F401, E402
from lerobot.teleoperators.utils import TeleopEvents # noqa: E402
from lerobot.transport import services_pb2, services_pb2_grpc # noqa: E402
from lerobot.transport.utils import ( # noqa: E402
bytes_to_state_dict,
grpc_channel_options,
python_object_to_bytes,
receive_bytes_in_chunks,
send_bytes_in_chunks,
transitions_to_bytes,
)
from lerobot.utils.device_utils import get_safe_torch_device # noqa: E402
from lerobot.utils.process import ProcessSignalHandler # noqa: E402
from lerobot.utils.random_utils import set_seed # noqa: E402
from lerobot.utils.robot_utils import precise_sleep # noqa: E402
from lerobot.utils.transition import ( # noqa: E402
Transition,
move_transition_to_device,
)
from lerobot.utils.utils import (
from lerobot.utils.utils import ( # noqa: E402
TimerManager,
init_logging,
)
from .algorithms.base import RLAlgorithm
from .algorithms.factory import make_algorithm
if TYPE_CHECKING or _grpc_available:
import grpc
from lerobot.transport import services_pb2, services_pb2_grpc
from lerobot.transport.utils import (
bytes_to_state_dict,
grpc_channel_options,
python_object_to_bytes,
receive_bytes_in_chunks,
send_bytes_in_chunks,
transitions_to_bytes,
)
else:
grpc = None
services_pb2 = None
services_pb2_grpc = None
bytes_to_state_dict = None
grpc_channel_options = None
python_object_to_bytes = None
receive_bytes_in_chunks = None
send_bytes_in_chunks = None
transitions_to_bytes = None
from .gym_manipulator import (
from .algorithms.base import RLAlgorithm # noqa: E402
from .algorithms.factory import make_algorithm # noqa: E402
from .gym_manipulator import ( # noqa: E402
make_processors,
make_robot_env,
reset_and_build_transition,
step_env_and_process_transition,
)
from .queue import get_last_item_from_queue
from .train_rl import TrainRLServerPipelineConfig
from .queue import get_last_item_from_queue # noqa: E402
from .train_rl import TrainRLServerPipelineConfig # noqa: E402
# Main entry point
@parser.wrap()
def actor_cli(cfg: TrainRLServerPipelineConfig):
require_package("grpcio", extra="hilserl", import_name="grpc")
cfg.validate()
display_pid = False
if not use_threads(cfg):
@@ -432,7 +421,7 @@ def act_with_policy(
def establish_learner_connection(
stub: "services_pb2_grpc.LearnerServiceStub",
stub: services_pb2_grpc.LearnerServiceStub,
shutdown_event: Any, # Event
attempts: int = 30,
) -> bool:
@@ -465,7 +454,7 @@ def establish_learner_connection(
def learner_service_client(
host: str = "127.0.0.1",
port: int = 50051,
) -> "tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]":
) -> tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]:
"""Return a client for the learner service.
GRPC uses HTTP/2, which is a binary protocol and multiplexes requests over a single connection.
@@ -488,8 +477,8 @@ def receive_policy(
cfg: TrainRLServerPipelineConfig,
parameters_queue: Queue,
shutdown_event: Any, # Event
learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None,
grpc_channel: "grpc.Channel | None" = None,
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
grpc_channel: grpc.Channel | None = None,
) -> None:
"""Receive parameters from the learner.
@@ -542,8 +531,8 @@ def send_transitions(
cfg: TrainRLServerPipelineConfig,
transitions_queue: Queue,
shutdown_event: Any, # Event
learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None,
grpc_channel: "grpc.Channel | None" = None,
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
grpc_channel: grpc.Channel | None = None,
) -> None:
"""Send transitions to the learner.
@@ -598,8 +587,8 @@ def send_interactions(
cfg: TrainRLServerPipelineConfig,
interactions_queue: Queue,
shutdown_event: Any, # Event
learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None,
grpc_channel: "grpc.Channel | None" = None,
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
grpc_channel: grpc.Channel | None = None,
) -> None:
"""Send interactions to the learner.
@@ -657,7 +646,7 @@ def transitions_stream(
shutdown_event: Any, # Event
transitions_queue: Queue,
timeout: float,
) -> "Generator[Any, None, services_pb2.Empty]":
) -> Generator[Any, None, services_pb2.Empty]:
while not shutdown_event.is_set():
try:
message = transitions_queue.get(block=True, timeout=timeout)
@@ -676,7 +665,7 @@ def interactions_stream(
shutdown_event: Any, # Event
interactions_queue: Queue,
timeout: float,
) -> "Generator[Any, None, services_pb2.Empty]":
) -> Generator[Any, None, services_pb2.Empty]:
while not shutdown_event.is_set():
try:
message = interactions_queue.get(block=True, timeout=timeout)
+43 -50
View File
@@ -51,31 +51,44 @@ import time
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from pprint import pformat
from typing import TYPE_CHECKING, Any
from typing import Any
import torch
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from safetensors.torch import load_file as load_safetensors
from termcolor import colored
from torch import nn
from torch.multiprocessing import Queue
from torch.optim.optimizer import Optimizer
from lerobot.utils.import_utils import require_package
from lerobot.cameras import opencv # noqa: F401
from lerobot.common.train_utils import (
# Fail fast with a friendly error if the optional ``hilserl`` extra is missing.
require_package("grpcio", extra="hilserl", import_name="grpc")
import grpc # noqa: E402
import torch # noqa: E402
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE # noqa: E402
from safetensors.torch import load_file as load_safetensors # noqa: E402
from termcolor import colored # noqa: E402
from torch import nn # noqa: E402
from torch.multiprocessing import Queue # noqa: E402
from torch.optim.optimizer import Optimizer # noqa: E402
from lerobot.cameras import opencv # noqa: F401, E402
from lerobot.common.train_utils import ( # noqa: E402
get_step_checkpoint_dir,
load_training_state as utils_load_training_state,
save_checkpoint,
update_last_checkpoint,
)
from lerobot.common.wandb_utils import WandBLogger
from lerobot.configs import parser
from lerobot.datasets import LeRobotDataset, make_dataset
from lerobot.policies import make_policy, make_pre_post_processors
from lerobot.robots import so_follower # noqa: F401
from lerobot.teleoperators import gamepad, so_leader # noqa: F401
from lerobot.teleoperators.utils import TeleopEvents
from lerobot.utils.constants import (
from lerobot.common.wandb_utils import WandBLogger # noqa: E402
from lerobot.configs import parser # noqa: E402
from lerobot.datasets import LeRobotDataset, make_dataset # noqa: E402
from lerobot.policies import make_policy, make_pre_post_processors # noqa: E402
from lerobot.robots import so_follower # noqa: F401, E402
from lerobot.teleoperators import gamepad, so_leader # noqa: F401, E402
from lerobot.teleoperators.utils import TeleopEvents # noqa: E402
from lerobot.transport import services_pb2_grpc # noqa: E402
from lerobot.transport.utils import ( # noqa: E402
MAX_MESSAGE_SIZE,
bytes_to_python_object,
bytes_to_transitions,
state_to_bytes,
)
from lerobot.utils.constants import ( # noqa: E402
ACTION,
ALGORITHM_DIR,
CHECKPOINTS_DIR,
@@ -84,46 +97,26 @@ from lerobot.utils.constants import (
TRAINING_STATE_DIR,
TRAINING_STEP,
)
from lerobot.utils.device_utils import get_safe_torch_device
from lerobot.utils.import_utils import _grpc_available, require_package
from lerobot.utils.io_utils import load_json, write_json
from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.random_utils import set_seed
from lerobot.utils.utils import (
from lerobot.utils.device_utils import get_safe_torch_device # noqa: E402
from lerobot.utils.io_utils import load_json, write_json # noqa: E402
from lerobot.utils.process import ProcessSignalHandler # noqa: E402
from lerobot.utils.random_utils import set_seed # noqa: E402
from lerobot.utils.utils import ( # noqa: E402
format_big_number,
init_logging,
)
if TYPE_CHECKING or _grpc_available:
import grpc
from lerobot.transport import services_pb2_grpc
from lerobot.transport.utils import (
MAX_MESSAGE_SIZE,
bytes_to_python_object,
bytes_to_transitions,
state_to_bytes,
)
else:
grpc = None
services_pb2_grpc = None
MAX_MESSAGE_SIZE = None
bytes_to_python_object = None
bytes_to_transitions = None
state_to_bytes = None
from .algorithms.base import RLAlgorithm
from .algorithms.factory import make_algorithm
from .buffer import ReplayBuffer
from .data_sources import OnlineOfflineMixer
from .learner_service import MAX_WORKERS, SHUTDOWN_TIMEOUT, LearnerService
from .train_rl import TrainRLServerPipelineConfig
from .trainer import RLTrainer
from .algorithms.base import RLAlgorithm # noqa: E402
from .algorithms.factory import make_algorithm # noqa: E402
from .buffer import ReplayBuffer # noqa: E402
from .data_sources import OnlineOfflineMixer # noqa: E402
from .learner_service import MAX_WORKERS, SHUTDOWN_TIMEOUT, LearnerService # noqa: E402
from .train_rl import TrainRLServerPipelineConfig # noqa: E402
from .trainer import RLTrainer # noqa: E402
@parser.wrap()
def train_cli(cfg: TrainRLServerPipelineConfig):
require_package("grpcio", extra="hilserl", import_name="grpc")
if not use_threads(cfg):
import torch.multiprocessing as mp
+7 -15
View File
@@ -18,29 +18,22 @@
import logging
import time
from multiprocessing import Event, Queue
from typing import TYPE_CHECKING
from lerobot.utils.import_utils import _grpc_available, require_package
from lerobot.utils.import_utils import require_package
if TYPE_CHECKING or _grpc_available:
from lerobot.transport import services_pb2, services_pb2_grpc
from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks
# Fail fast with a friendly error if the optional ``hilserl`` extra is missing.
require_package("grpcio", extra="hilserl", import_name="grpc")
_ServicerBase = services_pb2_grpc.LearnerServiceServicer
else:
services_pb2 = None
services_pb2_grpc = None
receive_bytes_in_chunks = None
send_bytes_in_chunks = None
_ServicerBase = object
from lerobot.transport import services_pb2, services_pb2_grpc # noqa: E402
from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks # noqa: E402
from .queue import get_last_item_from_queue
from .queue import get_last_item_from_queue # noqa: E402
MAX_WORKERS = 3 # Stream parameters, send transitions and interactions
SHUTDOWN_TIMEOUT = 10
class LearnerService(_ServicerBase):
class LearnerService(services_pb2_grpc.LearnerServiceServicer):
"""
Implementation of the LearnerService gRPC service
This service is used to send parameters to the Actor and receive transitions and interactions from the Actor
@@ -56,7 +49,6 @@ class LearnerService(_ServicerBase):
interaction_message_queue: Queue,
queue_get_timeout: float = 0.001,
):
require_package("grpcio", extra="hilserl", import_name="grpc")
self.shutdown_event = shutdown_event
self.parameters_queue = parameters_queue
self.seconds_between_pushes = seconds_between_pushes