From d28316d1b6e7e38071627f2786a5407244a409d6 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Mon, 11 May 2026 14:05:31 +0200 Subject: [PATCH] chore(rl): manage import pattern in actor (#3564) * chore(rl): manage import pattern in actor * chore(rl): optional grpc imports in learner; quote grpc ServicerContext types --------- Co-authored-by: Khalil Meftah --- pyproject.toml | 2 +- src/lerobot/rl/actor.py | 102 +++++++++++++++++------------- src/lerobot/rl/learner.py | 77 +++++++++++----------- src/lerobot/rl/learner_service.py | 34 ++++++---- 4 files changed, 122 insertions(+), 93 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0ae3abd73..e36525bdf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -195,7 +195,7 @@ groot = [ sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"] xvla = ["lerobot[transformers-dep]"] eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"] -hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"] +hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"] # Features async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"] diff --git a/src/lerobot/rl/actor.py b/src/lerobot/rl/actor.py index e35b9ade3..bfc7f1882 100644 --- a/src/lerobot/rl/actor.py +++ b/src/lerobot/rl/actor.py @@ -52,63 +52,75 @@ import time from collections.abc import Generator from functools import lru_cache from queue import Empty -from typing import Any +from typing import TYPE_CHECKING, Any -from lerobot.utils.import_utils import require_package +from lerobot.utils.import_utils import _grpc_available, require_package -# Fail fast with a friendly error if the optional ``hilserl`` extra is missing. -require_package("grpcio", extra="hilserl", import_name="grpc") +if TYPE_CHECKING or _grpc_available: + import grpc -import grpc # noqa: E402 -import torch # noqa: E402 -from torch import nn # noqa: E402 -from torch.multiprocessing import Queue # noqa: E402 + from lerobot.transport import services_pb2, services_pb2_grpc + from lerobot.transport.utils import ( + bytes_to_state_dict, + grpc_channel_options, + python_object_to_bytes, + receive_bytes_in_chunks, + send_bytes_in_chunks, + transitions_to_bytes, + ) +else: + grpc = None + services_pb2 = None + services_pb2_grpc = None + bytes_to_state_dict = None + grpc_channel_options = None + python_object_to_bytes = None + receive_bytes_in_chunks = None + send_bytes_in_chunks = None + transitions_to_bytes = None -from lerobot.cameras import opencv # noqa: F401, E402 -from lerobot.configs import parser # noqa: E402 -from lerobot.policies import make_policy, make_pre_post_processors # noqa: E402 -from lerobot.processor import TransitionKey # noqa: E402 -from lerobot.robots import so_follower # noqa: F401, E402 -from lerobot.teleoperators import gamepad, so_leader # noqa: F401, E402 -from lerobot.teleoperators.utils import TeleopEvents # noqa: E402 -from lerobot.transport import services_pb2, services_pb2_grpc # noqa: E402 -from lerobot.transport.utils import ( # noqa: E402 - bytes_to_state_dict, - grpc_channel_options, - python_object_to_bytes, - receive_bytes_in_chunks, - send_bytes_in_chunks, - transitions_to_bytes, -) -from lerobot.utils.device_utils import get_safe_torch_device # noqa: E402 -from lerobot.utils.process import ProcessSignalHandler # noqa: E402 -from lerobot.utils.random_utils import set_seed # noqa: E402 -from lerobot.utils.robot_utils import precise_sleep # noqa: E402 -from lerobot.utils.transition import ( # noqa: E402 +import torch +from torch import nn +from torch.multiprocessing import Queue + +from lerobot.cameras import opencv # noqa: F401 +from lerobot.configs import parser +from lerobot.policies import make_policy, make_pre_post_processors +from lerobot.processor import TransitionKey +from lerobot.robots import so_follower # noqa: F401 +from lerobot.teleoperators import gamepad, so_leader # noqa: F401 +from lerobot.teleoperators.utils import TeleopEvents +from lerobot.utils.device_utils import get_safe_torch_device +from lerobot.utils.process import ProcessSignalHandler +from lerobot.utils.random_utils import set_seed +from lerobot.utils.robot_utils import precise_sleep +from lerobot.utils.transition import ( Transition, move_transition_to_device, ) -from lerobot.utils.utils import ( # noqa: E402 +from lerobot.utils.utils import ( TimerManager, init_logging, ) -from .algorithms.base import RLAlgorithm # noqa: E402 -from .algorithms.factory import make_algorithm # noqa: E402 -from .gym_manipulator import ( # noqa: E402 +from .algorithms.base import RLAlgorithm +from .algorithms.factory import make_algorithm +from .gym_manipulator import ( make_processors, make_robot_env, reset_and_build_transition, step_env_and_process_transition, ) -from .queue import get_last_item_from_queue # noqa: E402 -from .train_rl import TrainRLServerPipelineConfig # noqa: E402 +from .queue import get_last_item_from_queue +from .train_rl import TrainRLServerPipelineConfig # Main entry point @parser.wrap() def actor_cli(cfg: TrainRLServerPipelineConfig): + # Fail fast with a friendly error if the optional ``hilserl`` extra is missing. + require_package("grpcio", extra="hilserl", import_name="grpc") cfg.validate() display_pid = False if not use_threads(cfg): @@ -421,7 +433,7 @@ def act_with_policy( def establish_learner_connection( - stub: services_pb2_grpc.LearnerServiceStub, + stub: "services_pb2_grpc.LearnerServiceStub", shutdown_event: Any, # Event attempts: int = 30, ) -> bool: @@ -454,7 +466,7 @@ def establish_learner_connection( def learner_service_client( host: str = "127.0.0.1", port: int = 50051, -) -> tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]: +) -> "tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]": """Return a client for the learner service. GRPC uses HTTP/2, which is a binary protocol and multiplexes requests over a single connection. @@ -477,8 +489,8 @@ def receive_policy( cfg: TrainRLServerPipelineConfig, parameters_queue: Queue, shutdown_event: Any, # Event - learner_client: services_pb2_grpc.LearnerServiceStub | None = None, - grpc_channel: grpc.Channel | None = None, + learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None, + grpc_channel: "grpc.Channel | None" = None, ) -> None: """Receive parameters from the learner. @@ -531,8 +543,8 @@ def send_transitions( cfg: TrainRLServerPipelineConfig, transitions_queue: Queue, shutdown_event: Any, # Event - learner_client: services_pb2_grpc.LearnerServiceStub | None = None, - grpc_channel: grpc.Channel | None = None, + learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None, + grpc_channel: "grpc.Channel | None" = None, ) -> None: """Send transitions to the learner. @@ -587,8 +599,8 @@ def send_interactions( cfg: TrainRLServerPipelineConfig, interactions_queue: Queue, shutdown_event: Any, # Event - learner_client: services_pb2_grpc.LearnerServiceStub | None = None, - grpc_channel: grpc.Channel | None = None, + learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None, + grpc_channel: "grpc.Channel | None" = None, ) -> None: """Send interactions to the learner. @@ -646,7 +658,7 @@ def transitions_stream( shutdown_event: Any, # Event transitions_queue: Queue, timeout: float, -) -> Generator[Any, None, services_pb2.Empty]: +) -> "Generator[Any, None, services_pb2.Empty]": while not shutdown_event.is_set(): try: message = transitions_queue.get(block=True, timeout=timeout) @@ -665,7 +677,7 @@ def interactions_stream( shutdown_event: Any, # Event interactions_queue: Queue, timeout: float, -) -> Generator[Any, None, services_pb2.Empty]: +) -> "Generator[Any, None, services_pb2.Empty]": while not shutdown_event.is_set(): try: message = interactions_queue.get(block=True, timeout=timeout) diff --git a/src/lerobot/rl/learner.py b/src/lerobot/rl/learner.py index 0f89dc439..41cfd8c03 100644 --- a/src/lerobot/rl/learner.py +++ b/src/lerobot/rl/learner.py @@ -51,44 +51,47 @@ import time from concurrent.futures import ThreadPoolExecutor from pathlib import Path from pprint import pformat -from typing import Any +from typing import TYPE_CHECKING, Any -from lerobot.utils.import_utils import require_package +from lerobot.utils.import_utils import _grpc_available, require_package -# Fail fast with a friendly error if the optional ``hilserl`` extra is missing. -require_package("grpcio", extra="hilserl", import_name="grpc") +if TYPE_CHECKING or _grpc_available: + import grpc -import grpc # noqa: E402 -import torch # noqa: E402 -from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE # noqa: E402 -from safetensors.torch import load_file as load_safetensors # noqa: E402 -from termcolor import colored # noqa: E402 -from torch import nn # noqa: E402 -from torch.multiprocessing import Queue # noqa: E402 -from torch.optim.optimizer import Optimizer # noqa: E402 + from lerobot.transport import services_pb2_grpc +else: + grpc = None + services_pb2_grpc = None -from lerobot.cameras import opencv # noqa: F401, E402 -from lerobot.common.train_utils import ( # noqa: E402 +import torch +from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE +from safetensors.torch import load_file as load_safetensors +from termcolor import colored +from torch import nn +from torch.multiprocessing import Queue +from torch.optim.optimizer import Optimizer + +from lerobot.cameras import opencv # noqa: F401 +from lerobot.common.train_utils import ( get_step_checkpoint_dir, load_training_state as utils_load_training_state, save_checkpoint, update_last_checkpoint, ) -from lerobot.common.wandb_utils import WandBLogger # noqa: E402 -from lerobot.configs import parser # noqa: E402 -from lerobot.datasets import LeRobotDataset, make_dataset # noqa: E402 -from lerobot.policies import make_policy, make_pre_post_processors # noqa: E402 -from lerobot.robots import so_follower # noqa: F401, E402 -from lerobot.teleoperators import gamepad, so_leader # noqa: F401, E402 -from lerobot.teleoperators.utils import TeleopEvents # noqa: E402 -from lerobot.transport import services_pb2_grpc # noqa: E402 -from lerobot.transport.utils import ( # noqa: E402 +from lerobot.common.wandb_utils import WandBLogger +from lerobot.configs import parser +from lerobot.datasets import LeRobotDataset, make_dataset +from lerobot.policies import make_policy, make_pre_post_processors +from lerobot.robots import so_follower # noqa: F401 +from lerobot.teleoperators import gamepad, so_leader # noqa: F401 +from lerobot.teleoperators.utils import TeleopEvents +from lerobot.transport.utils import ( MAX_MESSAGE_SIZE, bytes_to_python_object, bytes_to_transitions, state_to_bytes, ) -from lerobot.utils.constants import ( # noqa: E402 +from lerobot.utils.constants import ( ACTION, ALGORITHM_DIR, CHECKPOINTS_DIR, @@ -97,26 +100,28 @@ from lerobot.utils.constants import ( # noqa: E402 TRAINING_STATE_DIR, TRAINING_STEP, ) -from lerobot.utils.device_utils import get_safe_torch_device # noqa: E402 -from lerobot.utils.io_utils import load_json, write_json # noqa: E402 -from lerobot.utils.process import ProcessSignalHandler # noqa: E402 -from lerobot.utils.random_utils import set_seed # noqa: E402 -from lerobot.utils.utils import ( # noqa: E402 +from lerobot.utils.device_utils import get_safe_torch_device +from lerobot.utils.io_utils import load_json, write_json +from lerobot.utils.process import ProcessSignalHandler +from lerobot.utils.random_utils import set_seed +from lerobot.utils.utils import ( format_big_number, init_logging, ) -from .algorithms.base import RLAlgorithm # noqa: E402 -from .algorithms.factory import make_algorithm # noqa: E402 -from .buffer import ReplayBuffer # noqa: E402 -from .data_sources import OnlineOfflineMixer # noqa: E402 -from .learner_service import MAX_WORKERS, SHUTDOWN_TIMEOUT, LearnerService # noqa: E402 -from .train_rl import TrainRLServerPipelineConfig # noqa: E402 -from .trainer import RLTrainer # noqa: E402 +from .algorithms.base import RLAlgorithm +from .algorithms.factory import make_algorithm +from .buffer import ReplayBuffer +from .data_sources import OnlineOfflineMixer +from .learner_service import MAX_WORKERS, SHUTDOWN_TIMEOUT, LearnerService +from .train_rl import TrainRLServerPipelineConfig +from .trainer import RLTrainer @parser.wrap() def train_cli(cfg: TrainRLServerPipelineConfig): + # Fail fast with a friendly error if the optional ``hilserl`` extra is missing. + require_package("grpcio", extra="hilserl", import_name="grpc") if not use_threads(cfg): import torch.multiprocessing as mp diff --git a/src/lerobot/rl/learner_service.py b/src/lerobot/rl/learner_service.py index a95eb37a6..7a4df7136 100644 --- a/src/lerobot/rl/learner_service.py +++ b/src/lerobot/rl/learner_service.py @@ -18,22 +18,32 @@ import logging import time from multiprocessing import Event, Queue +from typing import TYPE_CHECKING -from lerobot.utils.import_utils import require_package +from lerobot.utils.import_utils import _grpc_available -# Fail fast with a friendly error if the optional ``hilserl`` extra is missing. -require_package("grpcio", extra="hilserl", import_name="grpc") +from .queue import get_last_item_from_queue -from lerobot.transport import services_pb2, services_pb2_grpc # noqa: E402 -from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks # noqa: E402 +if TYPE_CHECKING or _grpc_available: + import grpc -from .queue import get_last_item_from_queue # noqa: E402 + from lerobot.transport import services_pb2, services_pb2_grpc + from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks + + _ServicerBase = services_pb2_grpc.LearnerServiceServicer +else: + grpc = None + services_pb2 = None + services_pb2_grpc = None + receive_bytes_in_chunks = None + send_bytes_in_chunks = None + _ServicerBase = object MAX_WORKERS = 3 # Stream parameters, send transitions and interactions SHUTDOWN_TIMEOUT = 10 -class LearnerService(services_pb2_grpc.LearnerServiceServicer): +class LearnerService(_ServicerBase): """ Implementation of the LearnerService gRPC service This service is used to send parameters to the Actor and receive transitions and interactions from the Actor @@ -56,7 +66,9 @@ class LearnerService(services_pb2_grpc.LearnerServiceServicer): self.interaction_message_queue = interaction_message_queue self.queue_get_timeout = queue_get_timeout - def StreamParameters(self, request, context): # noqa: N802 + def StreamParameters( # noqa: N802 + self, request: "services_pb2.Empty", context: "grpc.ServicerContext" + ): # TODO: authorize the request logging.info("[LEARNER] Received request to stream parameters from the Actor") @@ -91,7 +103,7 @@ class LearnerService(services_pb2_grpc.LearnerServiceServicer): logging.info("[LEARNER] Stream parameters finished") return services_pb2.Empty() - def SendTransitions(self, request_iterator, _context): # noqa: N802 + def SendTransitions(self, request_iterator, _context: "grpc.ServicerContext"): # noqa: N802 # TODO: authorize the request logging.info("[LEARNER] Received request to receive transitions from the Actor") @@ -105,7 +117,7 @@ class LearnerService(services_pb2_grpc.LearnerServiceServicer): logging.debug("[LEARNER] Finished receiving transitions") return services_pb2.Empty() - def SendInteractions(self, request_iterator, _context): # noqa: N802 + def SendInteractions(self, request_iterator, _context: "grpc.ServicerContext"): # noqa: N802 # TODO: authorize the request logging.info("[LEARNER] Received request to receive interactions from the Actor") @@ -119,5 +131,5 @@ class LearnerService(services_pb2_grpc.LearnerServiceServicer): logging.debug("[LEARNER] Finished receiving interactions") return services_pb2.Empty() - def Ready(self, request, context): # noqa: N802 + def Ready(self, request: "services_pb2.Empty", context: "grpc.ServicerContext"): # noqa: N802 return services_pb2.Empty()