refactor(rl): hoist grpcio guard to module top in actor/learner

This commit is contained in:
Khalil Meftah
2026-05-09 20:49:53 +02:00
parent a5222d3f1d
commit e7886c2285
3 changed files with 97 additions and 123 deletions
+47 -58
View File
@@ -52,74 +52,63 @@ import time
from collections.abc import Generator from collections.abc import Generator
from functools import lru_cache from functools import lru_cache
from queue import Empty from queue import Empty
from typing import TYPE_CHECKING, Any from typing import Any
import torch from lerobot.utils.import_utils import require_package
from torch import nn
from torch.multiprocessing import Queue
from lerobot.cameras import opencv # noqa: F401 # Fail fast with a friendly error if the optional ``hilserl`` extra is missing.
from lerobot.configs import parser require_package("grpcio", extra="hilserl", import_name="grpc")
from lerobot.policies import make_policy, make_pre_post_processors
from lerobot.processor import TransitionKey import grpc # noqa: E402
from lerobot.robots import so_follower # noqa: F401 import torch # noqa: E402
from lerobot.teleoperators import gamepad, so_leader # noqa: F401 from torch import nn # noqa: E402
from lerobot.teleoperators.utils import TeleopEvents from torch.multiprocessing import Queue # noqa: E402
from lerobot.utils.device_utils import get_safe_torch_device
from lerobot.utils.import_utils import _grpc_available, require_package from lerobot.cameras import opencv # noqa: F401, E402
from lerobot.utils.process import ProcessSignalHandler from lerobot.configs import parser # noqa: E402
from lerobot.utils.random_utils import set_seed from lerobot.policies import make_policy, make_pre_post_processors # noqa: E402
from lerobot.utils.robot_utils import precise_sleep from lerobot.processor import TransitionKey # noqa: E402
from lerobot.utils.transition import ( from lerobot.robots import so_follower # noqa: F401, E402
from lerobot.teleoperators import gamepad, so_leader # noqa: F401, E402
from lerobot.teleoperators.utils import TeleopEvents # noqa: E402
from lerobot.transport import services_pb2, services_pb2_grpc # noqa: E402
from lerobot.transport.utils import ( # noqa: E402
bytes_to_state_dict,
grpc_channel_options,
python_object_to_bytes,
receive_bytes_in_chunks,
send_bytes_in_chunks,
transitions_to_bytes,
)
from lerobot.utils.device_utils import get_safe_torch_device # noqa: E402
from lerobot.utils.process import ProcessSignalHandler # noqa: E402
from lerobot.utils.random_utils import set_seed # noqa: E402
from lerobot.utils.robot_utils import precise_sleep # noqa: E402
from lerobot.utils.transition import ( # noqa: E402
Transition, Transition,
move_transition_to_device, move_transition_to_device,
) )
from lerobot.utils.utils import ( from lerobot.utils.utils import ( # noqa: E402
TimerManager, TimerManager,
init_logging, init_logging,
) )
from .algorithms.base import RLAlgorithm from .algorithms.base import RLAlgorithm # noqa: E402
from .algorithms.factory import make_algorithm from .algorithms.factory import make_algorithm # noqa: E402
from .gym_manipulator import ( # noqa: E402
if TYPE_CHECKING or _grpc_available:
import grpc
from lerobot.transport import services_pb2, services_pb2_grpc
from lerobot.transport.utils import (
bytes_to_state_dict,
grpc_channel_options,
python_object_to_bytes,
receive_bytes_in_chunks,
send_bytes_in_chunks,
transitions_to_bytes,
)
else:
grpc = None
services_pb2 = None
services_pb2_grpc = None
bytes_to_state_dict = None
grpc_channel_options = None
python_object_to_bytes = None
receive_bytes_in_chunks = None
send_bytes_in_chunks = None
transitions_to_bytes = None
from .gym_manipulator import (
make_processors, make_processors,
make_robot_env, make_robot_env,
reset_and_build_transition, reset_and_build_transition,
step_env_and_process_transition, step_env_and_process_transition,
) )
from .queue import get_last_item_from_queue from .queue import get_last_item_from_queue # noqa: E402
from .train_rl import TrainRLServerPipelineConfig from .train_rl import TrainRLServerPipelineConfig # noqa: E402
# Main entry point # Main entry point
@parser.wrap() @parser.wrap()
def actor_cli(cfg: TrainRLServerPipelineConfig): def actor_cli(cfg: TrainRLServerPipelineConfig):
require_package("grpcio", extra="hilserl", import_name="grpc")
cfg.validate() cfg.validate()
display_pid = False display_pid = False
if not use_threads(cfg): if not use_threads(cfg):
@@ -432,7 +421,7 @@ def act_with_policy(
def establish_learner_connection( def establish_learner_connection(
stub: "services_pb2_grpc.LearnerServiceStub", stub: services_pb2_grpc.LearnerServiceStub,
shutdown_event: Any, # Event shutdown_event: Any, # Event
attempts: int = 30, attempts: int = 30,
) -> bool: ) -> bool:
@@ -465,7 +454,7 @@ def establish_learner_connection(
def learner_service_client( def learner_service_client(
host: str = "127.0.0.1", host: str = "127.0.0.1",
port: int = 50051, port: int = 50051,
) -> "tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]": ) -> tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]:
"""Return a client for the learner service. """Return a client for the learner service.
GRPC uses HTTP/2, which is a binary protocol and multiplexes requests over a single connection. GRPC uses HTTP/2, which is a binary protocol and multiplexes requests over a single connection.
@@ -488,8 +477,8 @@ def receive_policy(
cfg: TrainRLServerPipelineConfig, cfg: TrainRLServerPipelineConfig,
parameters_queue: Queue, parameters_queue: Queue,
shutdown_event: Any, # Event shutdown_event: Any, # Event
learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None, learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
grpc_channel: "grpc.Channel | None" = None, grpc_channel: grpc.Channel | None = None,
) -> None: ) -> None:
"""Receive parameters from the learner. """Receive parameters from the learner.
@@ -542,8 +531,8 @@ def send_transitions(
cfg: TrainRLServerPipelineConfig, cfg: TrainRLServerPipelineConfig,
transitions_queue: Queue, transitions_queue: Queue,
shutdown_event: Any, # Event shutdown_event: Any, # Event
learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None, learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
grpc_channel: "grpc.Channel | None" = None, grpc_channel: grpc.Channel | None = None,
) -> None: ) -> None:
"""Send transitions to the learner. """Send transitions to the learner.
@@ -598,8 +587,8 @@ def send_interactions(
cfg: TrainRLServerPipelineConfig, cfg: TrainRLServerPipelineConfig,
interactions_queue: Queue, interactions_queue: Queue,
shutdown_event: Any, # Event shutdown_event: Any, # Event
learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None, learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
grpc_channel: "grpc.Channel | None" = None, grpc_channel: grpc.Channel | None = None,
) -> None: ) -> None:
"""Send interactions to the learner. """Send interactions to the learner.
@@ -657,7 +646,7 @@ def transitions_stream(
shutdown_event: Any, # Event shutdown_event: Any, # Event
transitions_queue: Queue, transitions_queue: Queue,
timeout: float, timeout: float,
) -> "Generator[Any, None, services_pb2.Empty]": ) -> Generator[Any, None, services_pb2.Empty]:
while not shutdown_event.is_set(): while not shutdown_event.is_set():
try: try:
message = transitions_queue.get(block=True, timeout=timeout) message = transitions_queue.get(block=True, timeout=timeout)
@@ -676,7 +665,7 @@ def interactions_stream(
shutdown_event: Any, # Event shutdown_event: Any, # Event
interactions_queue: Queue, interactions_queue: Queue,
timeout: float, timeout: float,
) -> "Generator[Any, None, services_pb2.Empty]": ) -> Generator[Any, None, services_pb2.Empty]:
while not shutdown_event.is_set(): while not shutdown_event.is_set():
try: try:
message = interactions_queue.get(block=True, timeout=timeout) message = interactions_queue.get(block=True, timeout=timeout)
+43 -50
View File
@@ -51,31 +51,44 @@ import time
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from pathlib import Path from pathlib import Path
from pprint import pformat from pprint import pformat
from typing import TYPE_CHECKING, Any from typing import Any
import torch from lerobot.utils.import_utils import require_package
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from safetensors.torch import load_file as load_safetensors
from termcolor import colored
from torch import nn
from torch.multiprocessing import Queue
from torch.optim.optimizer import Optimizer
from lerobot.cameras import opencv # noqa: F401 # Fail fast with a friendly error if the optional ``hilserl`` extra is missing.
from lerobot.common.train_utils import ( require_package("grpcio", extra="hilserl", import_name="grpc")
import grpc # noqa: E402
import torch # noqa: E402
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE # noqa: E402
from safetensors.torch import load_file as load_safetensors # noqa: E402
from termcolor import colored # noqa: E402
from torch import nn # noqa: E402
from torch.multiprocessing import Queue # noqa: E402
from torch.optim.optimizer import Optimizer # noqa: E402
from lerobot.cameras import opencv # noqa: F401, E402
from lerobot.common.train_utils import ( # noqa: E402
get_step_checkpoint_dir, get_step_checkpoint_dir,
load_training_state as utils_load_training_state, load_training_state as utils_load_training_state,
save_checkpoint, save_checkpoint,
update_last_checkpoint, update_last_checkpoint,
) )
from lerobot.common.wandb_utils import WandBLogger from lerobot.common.wandb_utils import WandBLogger # noqa: E402
from lerobot.configs import parser from lerobot.configs import parser # noqa: E402
from lerobot.datasets import LeRobotDataset, make_dataset from lerobot.datasets import LeRobotDataset, make_dataset # noqa: E402
from lerobot.policies import make_policy, make_pre_post_processors from lerobot.policies import make_policy, make_pre_post_processors # noqa: E402
from lerobot.robots import so_follower # noqa: F401 from lerobot.robots import so_follower # noqa: F401, E402
from lerobot.teleoperators import gamepad, so_leader # noqa: F401 from lerobot.teleoperators import gamepad, so_leader # noqa: F401, E402
from lerobot.teleoperators.utils import TeleopEvents from lerobot.teleoperators.utils import TeleopEvents # noqa: E402
from lerobot.utils.constants import ( from lerobot.transport import services_pb2_grpc # noqa: E402
from lerobot.transport.utils import ( # noqa: E402
MAX_MESSAGE_SIZE,
bytes_to_python_object,
bytes_to_transitions,
state_to_bytes,
)
from lerobot.utils.constants import ( # noqa: E402
ACTION, ACTION,
ALGORITHM_DIR, ALGORITHM_DIR,
CHECKPOINTS_DIR, CHECKPOINTS_DIR,
@@ -84,46 +97,26 @@ from lerobot.utils.constants import (
TRAINING_STATE_DIR, TRAINING_STATE_DIR,
TRAINING_STEP, TRAINING_STEP,
) )
from lerobot.utils.device_utils import get_safe_torch_device from lerobot.utils.device_utils import get_safe_torch_device # noqa: E402
from lerobot.utils.import_utils import _grpc_available, require_package from lerobot.utils.io_utils import load_json, write_json # noqa: E402
from lerobot.utils.io_utils import load_json, write_json from lerobot.utils.process import ProcessSignalHandler # noqa: E402
from lerobot.utils.process import ProcessSignalHandler from lerobot.utils.random_utils import set_seed # noqa: E402
from lerobot.utils.random_utils import set_seed from lerobot.utils.utils import ( # noqa: E402
from lerobot.utils.utils import (
format_big_number, format_big_number,
init_logging, init_logging,
) )
if TYPE_CHECKING or _grpc_available: from .algorithms.base import RLAlgorithm # noqa: E402
import grpc from .algorithms.factory import make_algorithm # noqa: E402
from .buffer import ReplayBuffer # noqa: E402
from lerobot.transport import services_pb2_grpc from .data_sources import OnlineOfflineMixer # noqa: E402
from lerobot.transport.utils import ( from .learner_service import MAX_WORKERS, SHUTDOWN_TIMEOUT, LearnerService # noqa: E402
MAX_MESSAGE_SIZE, from .train_rl import TrainRLServerPipelineConfig # noqa: E402
bytes_to_python_object, from .trainer import RLTrainer # noqa: E402
bytes_to_transitions,
state_to_bytes,
)
else:
grpc = None
services_pb2_grpc = None
MAX_MESSAGE_SIZE = None
bytes_to_python_object = None
bytes_to_transitions = None
state_to_bytes = None
from .algorithms.base import RLAlgorithm
from .algorithms.factory import make_algorithm
from .buffer import ReplayBuffer
from .data_sources import OnlineOfflineMixer
from .learner_service import MAX_WORKERS, SHUTDOWN_TIMEOUT, LearnerService
from .train_rl import TrainRLServerPipelineConfig
from .trainer import RLTrainer
@parser.wrap() @parser.wrap()
def train_cli(cfg: TrainRLServerPipelineConfig): def train_cli(cfg: TrainRLServerPipelineConfig):
require_package("grpcio", extra="hilserl", import_name="grpc")
if not use_threads(cfg): if not use_threads(cfg):
import torch.multiprocessing as mp import torch.multiprocessing as mp
+7 -15
View File
@@ -18,29 +18,22 @@
import logging import logging
import time import time
from multiprocessing import Event, Queue from multiprocessing import Event, Queue
from typing import TYPE_CHECKING
from lerobot.utils.import_utils import _grpc_available, require_package from lerobot.utils.import_utils import require_package
if TYPE_CHECKING or _grpc_available: # Fail fast with a friendly error if the optional ``hilserl`` extra is missing.
from lerobot.transport import services_pb2, services_pb2_grpc require_package("grpcio", extra="hilserl", import_name="grpc")
from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks
_ServicerBase = services_pb2_grpc.LearnerServiceServicer from lerobot.transport import services_pb2, services_pb2_grpc # noqa: E402
else: from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks # noqa: E402
services_pb2 = None
services_pb2_grpc = None
receive_bytes_in_chunks = None
send_bytes_in_chunks = None
_ServicerBase = object
from .queue import get_last_item_from_queue from .queue import get_last_item_from_queue # noqa: E402
MAX_WORKERS = 3 # Stream parameters, send transitions and interactions MAX_WORKERS = 3 # Stream parameters, send transitions and interactions
SHUTDOWN_TIMEOUT = 10 SHUTDOWN_TIMEOUT = 10
class LearnerService(_ServicerBase): class LearnerService(services_pb2_grpc.LearnerServiceServicer):
""" """
Implementation of the LearnerService gRPC service Implementation of the LearnerService gRPC service
This service is used to send parameters to the Actor and receive transitions and interactions from the Actor This service is used to send parameters to the Actor and receive transitions and interactions from the Actor
@@ -56,7 +49,6 @@ class LearnerService(_ServicerBase):
interaction_message_queue: Queue, interaction_message_queue: Queue,
queue_get_timeout: float = 0.001, queue_get_timeout: float = 0.001,
): ):
require_package("grpcio", extra="hilserl", import_name="grpc")
self.shutdown_event = shutdown_event self.shutdown_event = shutdown_event
self.parameters_queue = parameters_queue self.parameters_queue = parameters_queue
self.seconds_between_pushes = seconds_between_pushes self.seconds_between_pushes = seconds_between_pushes