From b1b2708e2f2678e276f8dc7181b1b59157fd4264 Mon Sep 17 00:00:00 2001 From: Khalil Meftah Date: Fri, 8 May 2026 11:03:00 +0200 Subject: [PATCH] refactor(rl): move grpcio guards to runtime entry points --- pyproject.toml | 3 --- src/lerobot/rl/actor.py | 43 ++++++++++++++++++++----------- src/lerobot/rl/gym_manipulator.py | 3 +-- src/lerobot/rl/learner.py | 36 ++++++++++++++++---------- src/lerobot/rl/learner_service.py | 19 ++++++++++---- src/lerobot/utils/import_utils.py | 1 + 6 files changed, 67 insertions(+), 38 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6f1e2558c..0ae3abd73 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -333,9 +333,6 @@ ignore = [ "__init__.py" = ["F401", "F403", "E402"] # E402: conditional-import guards (TYPE_CHECKING / is_package_available) must precede the imports they protect "src/lerobot/scripts/convert_dataset_v21_to_v30.py" = ["E402"] -"src/lerobot/rl/actor.py" = ["E402"] -"src/lerobot/rl/learner.py" = ["E402"] -"src/lerobot/rl/learner_service.py" = ["E402"] "src/lerobot/policies/wall_x/**" = ["N801", "N812", "SIM102", "SIM108", "SIM210", "SIM211", "B006", "B007", "SIM118"] # Supprese these as they are coming from original Qwen2_5_vl code TODO(pepijn): refactor original [tool.ruff.lint.isort] diff --git a/src/lerobot/rl/actor.py b/src/lerobot/rl/actor.py index 14e786268..e7820d14f 100644 --- a/src/lerobot/rl/actor.py +++ b/src/lerobot/rl/actor.py @@ -46,18 +46,15 @@ For more details on the complete HILSerl training workflow, see: https://github.com/michel-aractingi/lerobot-hilserl-guide """ +from __future__ import annotations + import logging import os import time from functools import lru_cache from queue import Empty -from typing import Any +from typing import TYPE_CHECKING, Any -from lerobot.utils.import_utils import require_package - -require_package("grpcio", extra="hilserl", import_name="grpc") - -import grpc import torch from torch import nn from torch.multiprocessing import Queue @@ -69,16 +66,8 @@ from lerobot.processor import TransitionKey from lerobot.robots import so_follower # noqa: F401 from lerobot.teleoperators import gamepad, so_leader # noqa: F401 from lerobot.teleoperators.utils import TeleopEvents -from lerobot.transport import services_pb2, services_pb2_grpc -from lerobot.transport.utils import ( - bytes_to_state_dict, - grpc_channel_options, - python_object_to_bytes, - receive_bytes_in_chunks, - send_bytes_in_chunks, - transitions_to_bytes, -) from lerobot.utils.device_utils import get_safe_torch_device +from lerobot.utils.import_utils import _grpc_available, require_package from lerobot.utils.process import ProcessSignalHandler from lerobot.utils.random_utils import set_seed from lerobot.utils.robot_utils import precise_sleep @@ -91,6 +80,29 @@ from lerobot.utils.utils import ( init_logging, ) +if TYPE_CHECKING or _grpc_available: + import grpc + + from lerobot.transport import services_pb2, services_pb2_grpc + from lerobot.transport.utils import ( + bytes_to_state_dict, + grpc_channel_options, + python_object_to_bytes, + receive_bytes_in_chunks, + send_bytes_in_chunks, + transitions_to_bytes, + ) +else: + grpc = None + services_pb2 = None + services_pb2_grpc = None + bytes_to_state_dict = None + grpc_channel_options = None + python_object_to_bytes = None + receive_bytes_in_chunks = None + send_bytes_in_chunks = None + transitions_to_bytes = None + from .gym_manipulator import ( make_processors, make_robot_env, @@ -105,6 +117,7 @@ from .train_rl import TrainRLServerPipelineConfig @parser.wrap() def actor_cli(cfg: TrainRLServerPipelineConfig): + require_package("grpcio", extra="hilserl", import_name="grpc") cfg.validate() display_pid = False if not use_threads(cfg): diff --git a/src/lerobot/rl/gym_manipulator.py b/src/lerobot/rl/gym_manipulator.py index af451bfeb..03f7b4eea 100644 --- a/src/lerobot/rl/gym_manipulator.py +++ b/src/lerobot/rl/gym_manipulator.py @@ -74,6 +74,7 @@ from lerobot.teleoperators import ( from lerobot.teleoperators.teleoperator import Teleoperator from lerobot.teleoperators.utils import TeleopEvents from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD +from lerobot.utils.import_utils import require_package from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.utils import log_say @@ -312,8 +313,6 @@ def make_robot_env(cfg: HILSerlRobotEnvConfig) -> tuple[gym.Env, Any]: # Check if this is a GymHIL simulation environment if cfg.name == "gym_hil": assert cfg.robot is None and cfg.teleop is None, "GymHIL environment does not support robot or teleop" - from lerobot.utils.import_utils import require_package - require_package("gym-hil", extra="hilserl", import_name="gym_hil") import gym_hil # noqa: F401 diff --git a/src/lerobot/rl/learner.py b/src/lerobot/rl/learner.py index 7d3190230..3caa07387 100644 --- a/src/lerobot/rl/learner.py +++ b/src/lerobot/rl/learner.py @@ -44,6 +44,8 @@ For more details on the complete HILSerl training workflow, see: https://github.com/michel-aractingi/lerobot-hilserl-guide """ +from __future__ import annotations + import logging import os import shutil @@ -51,13 +53,8 @@ import time from concurrent.futures import ThreadPoolExecutor from pathlib import Path from pprint import pformat -from typing import Any +from typing import TYPE_CHECKING, Any -from lerobot.utils.import_utils import require_package - -require_package("grpcio", extra="hilserl", import_name="grpc") - -import grpc import torch from termcolor import colored from torch import nn @@ -78,13 +75,6 @@ from lerobot.policies import make_policy, make_pre_post_processors from lerobot.robots import so_follower # noqa: F401 from lerobot.teleoperators import gamepad, so_leader # noqa: F401 from lerobot.teleoperators.utils import TeleopEvents -from lerobot.transport import services_pb2_grpc -from lerobot.transport.utils import ( - MAX_MESSAGE_SIZE, - bytes_to_python_object, - bytes_to_transitions, - state_to_bytes, -) from lerobot.utils.constants import ( ACTION, CHECKPOINTS_DIR, @@ -93,6 +83,7 @@ from lerobot.utils.constants import ( TRAINING_STATE_DIR, ) from lerobot.utils.device_utils import get_safe_torch_device +from lerobot.utils.import_utils import _grpc_available, require_package from lerobot.utils.process import ProcessSignalHandler from lerobot.utils.random_utils import set_seed from lerobot.utils.utils import ( @@ -100,6 +91,24 @@ from lerobot.utils.utils import ( init_logging, ) +if TYPE_CHECKING or _grpc_available: + import grpc + + from lerobot.transport import services_pb2_grpc + from lerobot.transport.utils import ( + MAX_MESSAGE_SIZE, + bytes_to_python_object, + bytes_to_transitions, + state_to_bytes, + ) +else: + grpc = None + services_pb2_grpc = None + MAX_MESSAGE_SIZE = None + bytes_to_python_object = None + bytes_to_transitions = None + state_to_bytes = None + from .algorithms.base import RLAlgorithm from .algorithms.factory import make_algorithm from .buffer import ReplayBuffer @@ -111,6 +120,7 @@ from .trainer import RLTrainer @parser.wrap() def train_cli(cfg: TrainRLServerPipelineConfig): + require_package("grpcio", extra="hilserl", import_name="grpc") if not use_threads(cfg): import torch.multiprocessing as mp diff --git a/src/lerobot/rl/learner_service.py b/src/lerobot/rl/learner_service.py index 358b01090..65af94038 100644 --- a/src/lerobot/rl/learner_service.py +++ b/src/lerobot/rl/learner_service.py @@ -18,13 +18,21 @@ import logging import time from multiprocessing import Event, Queue +from typing import TYPE_CHECKING -from lerobot.utils.import_utils import require_package +from lerobot.utils.import_utils import _grpc_available, require_package -require_package("grpcio", extra="hilserl", import_name="grpc") +if TYPE_CHECKING or _grpc_available: + from lerobot.transport import services_pb2, services_pb2_grpc + from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks -from lerobot.transport import services_pb2, services_pb2_grpc -from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks + _ServicerBase = services_pb2_grpc.LearnerServiceServicer +else: + services_pb2 = None + services_pb2_grpc = None + receive_bytes_in_chunks = None + send_bytes_in_chunks = None + _ServicerBase = object from .queue import get_last_item_from_queue @@ -32,7 +40,7 @@ MAX_WORKERS = 3 # Stream parameters, send transitions and interactions SHUTDOWN_TIMEOUT = 10 -class LearnerService(services_pb2_grpc.LearnerServiceServicer): +class LearnerService(_ServicerBase): """ Implementation of the LearnerService gRPC service This service is used to send parameters to the Actor and receive transitions and interactions from the Actor @@ -48,6 +56,7 @@ class LearnerService(services_pb2_grpc.LearnerServiceServicer): interaction_message_queue: Queue, queue_get_timeout: float = 0.001, ): + require_package("grpcio", extra="hilserl", import_name="grpc") self.shutdown_event = shutdown_event self.parameters_queue = parameters_queue self.seconds_between_pushes = seconds_between_pushes diff --git a/src/lerobot/utils/import_utils.py b/src/lerobot/utils/import_utils.py index bfa87fb86..6ba912bf5 100644 --- a/src/lerobot/utils/import_utils.py +++ b/src/lerobot/utils/import_utils.py @@ -132,6 +132,7 @@ _faker_available = is_package_available("faker") _pynput_available = is_package_available("pynput") _pygame_available = is_package_available("pygame") _qwen_vl_utils_available = is_package_available("qwen-vl-utils", import_name="qwen_vl_utils") +_grpc_available = is_package_available("grpcio", import_name="grpc") _wallx_deps_available = ( _transformers_available and _peft_available and _torchdiffeq_available and _qwen_vl_utils_available )