refactor(rl): move grpcio guards to runtime entry points

This commit is contained in:
Khalil Meftah
2026-05-08 11:03:00 +02:00
parent f5a5ca04e2
commit b1b2708e2f
6 changed files with 67 additions and 38 deletions
-3
View File
@@ -333,9 +333,6 @@ ignore = [
"__init__.py" = ["F401", "F403", "E402"] "__init__.py" = ["F401", "F403", "E402"]
# E402: conditional-import guards (TYPE_CHECKING / is_package_available) must precede the imports they protect # E402: conditional-import guards (TYPE_CHECKING / is_package_available) must precede the imports they protect
"src/lerobot/scripts/convert_dataset_v21_to_v30.py" = ["E402"] "src/lerobot/scripts/convert_dataset_v21_to_v30.py" = ["E402"]
"src/lerobot/rl/actor.py" = ["E402"]
"src/lerobot/rl/learner.py" = ["E402"]
"src/lerobot/rl/learner_service.py" = ["E402"]
"src/lerobot/policies/wall_x/**" = ["N801", "N812", "SIM102", "SIM108", "SIM210", "SIM211", "B006", "B007", "SIM118"] # Suppress these as they are coming from original Qwen2_5_vl code TODO(pepijn): refactor original "src/lerobot/policies/wall_x/**" = ["N801", "N812", "SIM102", "SIM108", "SIM210", "SIM211", "B006", "B007", "SIM118"] # Suppress these as they are coming from original Qwen2_5_vl code TODO(pepijn): refactor original
[tool.ruff.lint.isort] [tool.ruff.lint.isort]
+28 -15
View File
@@ -46,18 +46,15 @@ For more details on the complete HILSerl training workflow, see:
https://github.com/michel-aractingi/lerobot-hilserl-guide https://github.com/michel-aractingi/lerobot-hilserl-guide
""" """
from __future__ import annotations
import logging import logging
import os import os
import time import time
from functools import lru_cache from functools import lru_cache
from queue import Empty from queue import Empty
from typing import Any from typing import TYPE_CHECKING, Any
from lerobot.utils.import_utils import require_package
require_package("grpcio", extra="hilserl", import_name="grpc")
import grpc
import torch import torch
from torch import nn from torch import nn
from torch.multiprocessing import Queue from torch.multiprocessing import Queue
@@ -69,16 +66,8 @@ from lerobot.processor import TransitionKey
from lerobot.robots import so_follower # noqa: F401 from lerobot.robots import so_follower # noqa: F401
from lerobot.teleoperators import gamepad, so_leader # noqa: F401 from lerobot.teleoperators import gamepad, so_leader # noqa: F401
from lerobot.teleoperators.utils import TeleopEvents from lerobot.teleoperators.utils import TeleopEvents
from lerobot.transport import services_pb2, services_pb2_grpc
from lerobot.transport.utils import (
bytes_to_state_dict,
grpc_channel_options,
python_object_to_bytes,
receive_bytes_in_chunks,
send_bytes_in_chunks,
transitions_to_bytes,
)
from lerobot.utils.device_utils import get_safe_torch_device from lerobot.utils.device_utils import get_safe_torch_device
from lerobot.utils.import_utils import _grpc_available, require_package
from lerobot.utils.process import ProcessSignalHandler from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.random_utils import set_seed from lerobot.utils.random_utils import set_seed
from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.robot_utils import precise_sleep
@@ -91,6 +80,29 @@ from lerobot.utils.utils import (
init_logging, init_logging,
) )
if TYPE_CHECKING or _grpc_available:
import grpc
from lerobot.transport import services_pb2, services_pb2_grpc
from lerobot.transport.utils import (
bytes_to_state_dict,
grpc_channel_options,
python_object_to_bytes,
receive_bytes_in_chunks,
send_bytes_in_chunks,
transitions_to_bytes,
)
else:
grpc = None
services_pb2 = None
services_pb2_grpc = None
bytes_to_state_dict = None
grpc_channel_options = None
python_object_to_bytes = None
receive_bytes_in_chunks = None
send_bytes_in_chunks = None
transitions_to_bytes = None
from .gym_manipulator import ( from .gym_manipulator import (
make_processors, make_processors,
make_robot_env, make_robot_env,
@@ -105,6 +117,7 @@ from .train_rl import TrainRLServerPipelineConfig
@parser.wrap() @parser.wrap()
def actor_cli(cfg: TrainRLServerPipelineConfig): def actor_cli(cfg: TrainRLServerPipelineConfig):
require_package("grpcio", extra="hilserl", import_name="grpc")
cfg.validate() cfg.validate()
display_pid = False display_pid = False
if not use_threads(cfg): if not use_threads(cfg):
+1 -2
View File
@@ -74,6 +74,7 @@ from lerobot.teleoperators import (
from lerobot.teleoperators.teleoperator import Teleoperator from lerobot.teleoperators.teleoperator import Teleoperator
from lerobot.teleoperators.utils import TeleopEvents from lerobot.teleoperators.utils import TeleopEvents
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD
from lerobot.utils.import_utils import require_package
from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say from lerobot.utils.utils import log_say
@@ -312,8 +313,6 @@ def make_robot_env(cfg: HILSerlRobotEnvConfig) -> tuple[gym.Env, Any]:
# Check if this is a GymHIL simulation environment # Check if this is a GymHIL simulation environment
if cfg.name == "gym_hil": if cfg.name == "gym_hil":
assert cfg.robot is None and cfg.teleop is None, "GymHIL environment does not support robot or teleop" assert cfg.robot is None and cfg.teleop is None, "GymHIL environment does not support robot or teleop"
from lerobot.utils.import_utils import require_package
require_package("gym-hil", extra="hilserl", import_name="gym_hil") require_package("gym-hil", extra="hilserl", import_name="gym_hil")
import gym_hil # noqa: F401 import gym_hil # noqa: F401
+23 -13
View File
@@ -44,6 +44,8 @@ For more details on the complete HILSerl training workflow, see:
https://github.com/michel-aractingi/lerobot-hilserl-guide https://github.com/michel-aractingi/lerobot-hilserl-guide
""" """
from __future__ import annotations
import logging import logging
import os import os
import shutil import shutil
@@ -51,13 +53,8 @@ import time
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from pathlib import Path from pathlib import Path
from pprint import pformat from pprint import pformat
from typing import Any from typing import TYPE_CHECKING, Any
from lerobot.utils.import_utils import require_package
require_package("grpcio", extra="hilserl", import_name="grpc")
import grpc
import torch import torch
from termcolor import colored from termcolor import colored
from torch import nn from torch import nn
@@ -78,13 +75,6 @@ from lerobot.policies import make_policy, make_pre_post_processors
from lerobot.robots import so_follower # noqa: F401 from lerobot.robots import so_follower # noqa: F401
from lerobot.teleoperators import gamepad, so_leader # noqa: F401 from lerobot.teleoperators import gamepad, so_leader # noqa: F401
from lerobot.teleoperators.utils import TeleopEvents from lerobot.teleoperators.utils import TeleopEvents
from lerobot.transport import services_pb2_grpc
from lerobot.transport.utils import (
MAX_MESSAGE_SIZE,
bytes_to_python_object,
bytes_to_transitions,
state_to_bytes,
)
from lerobot.utils.constants import ( from lerobot.utils.constants import (
ACTION, ACTION,
CHECKPOINTS_DIR, CHECKPOINTS_DIR,
@@ -93,6 +83,7 @@ from lerobot.utils.constants import (
TRAINING_STATE_DIR, TRAINING_STATE_DIR,
) )
from lerobot.utils.device_utils import get_safe_torch_device from lerobot.utils.device_utils import get_safe_torch_device
from lerobot.utils.import_utils import _grpc_available, require_package
from lerobot.utils.process import ProcessSignalHandler from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.random_utils import set_seed from lerobot.utils.random_utils import set_seed
from lerobot.utils.utils import ( from lerobot.utils.utils import (
@@ -100,6 +91,24 @@ from lerobot.utils.utils import (
init_logging, init_logging,
) )
if TYPE_CHECKING or _grpc_available:
import grpc
from lerobot.transport import services_pb2_grpc
from lerobot.transport.utils import (
MAX_MESSAGE_SIZE,
bytes_to_python_object,
bytes_to_transitions,
state_to_bytes,
)
else:
grpc = None
services_pb2_grpc = None
MAX_MESSAGE_SIZE = None
bytes_to_python_object = None
bytes_to_transitions = None
state_to_bytes = None
from .algorithms.base import RLAlgorithm from .algorithms.base import RLAlgorithm
from .algorithms.factory import make_algorithm from .algorithms.factory import make_algorithm
from .buffer import ReplayBuffer from .buffer import ReplayBuffer
@@ -111,6 +120,7 @@ from .trainer import RLTrainer
@parser.wrap() @parser.wrap()
def train_cli(cfg: TrainRLServerPipelineConfig): def train_cli(cfg: TrainRLServerPipelineConfig):
require_package("grpcio", extra="hilserl", import_name="grpc")
if not use_threads(cfg): if not use_threads(cfg):
import torch.multiprocessing as mp import torch.multiprocessing as mp
+14 -5
View File
@@ -18,13 +18,21 @@
import logging import logging
import time import time
from multiprocessing import Event, Queue from multiprocessing import Event, Queue
from typing import TYPE_CHECKING
from lerobot.utils.import_utils import require_package from lerobot.utils.import_utils import _grpc_available, require_package
require_package("grpcio", extra="hilserl", import_name="grpc") if TYPE_CHECKING or _grpc_available:
from lerobot.transport import services_pb2, services_pb2_grpc
from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks
from lerobot.transport import services_pb2, services_pb2_grpc _ServicerBase = services_pb2_grpc.LearnerServiceServicer
from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks else:
services_pb2 = None
services_pb2_grpc = None
receive_bytes_in_chunks = None
send_bytes_in_chunks = None
_ServicerBase = object
from .queue import get_last_item_from_queue from .queue import get_last_item_from_queue
@@ -32,7 +40,7 @@ MAX_WORKERS = 3 # Stream parameters, send transitions and interactions
SHUTDOWN_TIMEOUT = 10 SHUTDOWN_TIMEOUT = 10
class LearnerService(services_pb2_grpc.LearnerServiceServicer): class LearnerService(_ServicerBase):
""" """
Implementation of the LearnerService gRPC service Implementation of the LearnerService gRPC service
This service is used to send parameters to the Actor and receive transitions and interactions from the Actor This service is used to send parameters to the Actor and receive transitions and interactions from the Actor
@@ -48,6 +56,7 @@ class LearnerService(services_pb2_grpc.LearnerServiceServicer):
interaction_message_queue: Queue, interaction_message_queue: Queue,
queue_get_timeout: float = 0.001, queue_get_timeout: float = 0.001,
): ):
require_package("grpcio", extra="hilserl", import_name="grpc")
self.shutdown_event = shutdown_event self.shutdown_event = shutdown_event
self.parameters_queue = parameters_queue self.parameters_queue = parameters_queue
self.seconds_between_pushes = seconds_between_pushes self.seconds_between_pushes = seconds_between_pushes
+1
View File
@@ -132,6 +132,7 @@ _faker_available = is_package_available("faker")
_pynput_available = is_package_available("pynput") _pynput_available = is_package_available("pynput")
_pygame_available = is_package_available("pygame") _pygame_available = is_package_available("pygame")
_qwen_vl_utils_available = is_package_available("qwen-vl-utils", import_name="qwen_vl_utils") _qwen_vl_utils_available = is_package_available("qwen-vl-utils", import_name="qwen_vl_utils")
_grpc_available = is_package_available("grpcio", import_name="grpc")
_wallx_deps_available = ( _wallx_deps_available = (
_transformers_available and _peft_available and _torchdiffeq_available and _qwen_vl_utils_available _transformers_available and _peft_available and _torchdiffeq_available and _qwen_vl_utils_available
) )