chore(rl): manage import pattern in actor (#3564)

* chore(rl): manage import pattern in actor

* chore(rl): optional grpc imports in learner; quote grpc ServicerContext types

---------

Co-authored-by: Khalil Meftah <khalil.meftah@huggingface.co>
This commit is contained in:
Steven Palma
2026-05-11 14:05:31 +02:00
committed by GitHub
parent e7886c2285
commit d28316d1b6
4 changed files with 122 additions and 93 deletions
+1 -1
View File
@@ -195,7 +195,7 @@ groot = [
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"] sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
xvla = ["lerobot[transformers-dep]"] xvla = ["lerobot[transformers-dep]"]
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"] eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"] hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
# Features # Features
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"] async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
+51 -39
View File
@@ -52,27 +52,15 @@ import time
from collections.abc import Generator from collections.abc import Generator
from functools import lru_cache from functools import lru_cache
from queue import Empty from queue import Empty
from typing import Any from typing import TYPE_CHECKING, Any
from lerobot.utils.import_utils import require_package from lerobot.utils.import_utils import _grpc_available, require_package
# Fail fast with a friendly error if the optional ``hilserl`` extra is missing. if TYPE_CHECKING or _grpc_available:
require_package("grpcio", extra="hilserl", import_name="grpc") import grpc
import grpc # noqa: E402 from lerobot.transport import services_pb2, services_pb2_grpc
import torch # noqa: E402 from lerobot.transport.utils import (
from torch import nn # noqa: E402
from torch.multiprocessing import Queue # noqa: E402
from lerobot.cameras import opencv # noqa: F401, E402
from lerobot.configs import parser # noqa: E402
from lerobot.policies import make_policy, make_pre_post_processors # noqa: E402
from lerobot.processor import TransitionKey # noqa: E402
from lerobot.robots import so_follower # noqa: F401, E402
from lerobot.teleoperators import gamepad, so_leader # noqa: F401, E402
from lerobot.teleoperators.utils import TeleopEvents # noqa: E402
from lerobot.transport import services_pb2, services_pb2_grpc # noqa: E402
from lerobot.transport.utils import ( # noqa: E402
bytes_to_state_dict, bytes_to_state_dict,
grpc_channel_options, grpc_channel_options,
python_object_to_bytes, python_object_to_bytes,
@@ -80,35 +68,59 @@ from lerobot.transport.utils import ( # noqa: E402
send_bytes_in_chunks, send_bytes_in_chunks,
transitions_to_bytes, transitions_to_bytes,
) )
from lerobot.utils.device_utils import get_safe_torch_device # noqa: E402 else:
from lerobot.utils.process import ProcessSignalHandler # noqa: E402 grpc = None
from lerobot.utils.random_utils import set_seed # noqa: E402 services_pb2 = None
from lerobot.utils.robot_utils import precise_sleep # noqa: E402 services_pb2_grpc = None
from lerobot.utils.transition import ( # noqa: E402 bytes_to_state_dict = None
grpc_channel_options = None
python_object_to_bytes = None
receive_bytes_in_chunks = None
send_bytes_in_chunks = None
transitions_to_bytes = None
import torch
from torch import nn
from torch.multiprocessing import Queue
from lerobot.cameras import opencv # noqa: F401
from lerobot.configs import parser
from lerobot.policies import make_policy, make_pre_post_processors
from lerobot.processor import TransitionKey
from lerobot.robots import so_follower # noqa: F401
from lerobot.teleoperators import gamepad, so_leader # noqa: F401
from lerobot.teleoperators.utils import TeleopEvents
from lerobot.utils.device_utils import get_safe_torch_device
from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.random_utils import set_seed
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.transition import (
Transition, Transition,
move_transition_to_device, move_transition_to_device,
) )
from lerobot.utils.utils import ( # noqa: E402 from lerobot.utils.utils import (
TimerManager, TimerManager,
init_logging, init_logging,
) )
from .algorithms.base import RLAlgorithm # noqa: E402 from .algorithms.base import RLAlgorithm
from .algorithms.factory import make_algorithm # noqa: E402 from .algorithms.factory import make_algorithm
from .gym_manipulator import ( # noqa: E402 from .gym_manipulator import (
make_processors, make_processors,
make_robot_env, make_robot_env,
reset_and_build_transition, reset_and_build_transition,
step_env_and_process_transition, step_env_and_process_transition,
) )
from .queue import get_last_item_from_queue # noqa: E402 from .queue import get_last_item_from_queue
from .train_rl import TrainRLServerPipelineConfig # noqa: E402 from .train_rl import TrainRLServerPipelineConfig
# Main entry point # Main entry point
@parser.wrap() @parser.wrap()
def actor_cli(cfg: TrainRLServerPipelineConfig): def actor_cli(cfg: TrainRLServerPipelineConfig):
# Fail fast with a friendly error if the optional ``hilserl`` extra is missing.
require_package("grpcio", extra="hilserl", import_name="grpc")
cfg.validate() cfg.validate()
display_pid = False display_pid = False
if not use_threads(cfg): if not use_threads(cfg):
@@ -421,7 +433,7 @@ def act_with_policy(
def establish_learner_connection( def establish_learner_connection(
stub: services_pb2_grpc.LearnerServiceStub, stub: "services_pb2_grpc.LearnerServiceStub",
shutdown_event: Any, # Event shutdown_event: Any, # Event
attempts: int = 30, attempts: int = 30,
) -> bool: ) -> bool:
@@ -454,7 +466,7 @@ def establish_learner_connection(
def learner_service_client( def learner_service_client(
host: str = "127.0.0.1", host: str = "127.0.0.1",
port: int = 50051, port: int = 50051,
) -> tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]: ) -> "tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]":
"""Return a client for the learner service. """Return a client for the learner service.
GRPC uses HTTP/2, which is a binary protocol and multiplexes requests over a single connection. GRPC uses HTTP/2, which is a binary protocol and multiplexes requests over a single connection.
@@ -477,8 +489,8 @@ def receive_policy(
cfg: TrainRLServerPipelineConfig, cfg: TrainRLServerPipelineConfig,
parameters_queue: Queue, parameters_queue: Queue,
shutdown_event: Any, # Event shutdown_event: Any, # Event
learner_client: services_pb2_grpc.LearnerServiceStub | None = None, learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None,
grpc_channel: grpc.Channel | None = None, grpc_channel: "grpc.Channel | None" = None,
) -> None: ) -> None:
"""Receive parameters from the learner. """Receive parameters from the learner.
@@ -531,8 +543,8 @@ def send_transitions(
cfg: TrainRLServerPipelineConfig, cfg: TrainRLServerPipelineConfig,
transitions_queue: Queue, transitions_queue: Queue,
shutdown_event: Any, # Event shutdown_event: Any, # Event
learner_client: services_pb2_grpc.LearnerServiceStub | None = None, learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None,
grpc_channel: grpc.Channel | None = None, grpc_channel: "grpc.Channel | None" = None,
) -> None: ) -> None:
"""Send transitions to the learner. """Send transitions to the learner.
@@ -587,8 +599,8 @@ def send_interactions(
cfg: TrainRLServerPipelineConfig, cfg: TrainRLServerPipelineConfig,
interactions_queue: Queue, interactions_queue: Queue,
shutdown_event: Any, # Event shutdown_event: Any, # Event
learner_client: services_pb2_grpc.LearnerServiceStub | None = None, learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None,
grpc_channel: grpc.Channel | None = None, grpc_channel: "grpc.Channel | None" = None,
) -> None: ) -> None:
"""Send interactions to the learner. """Send interactions to the learner.
@@ -646,7 +658,7 @@ def transitions_stream(
shutdown_event: Any, # Event shutdown_event: Any, # Event
transitions_queue: Queue, transitions_queue: Queue,
timeout: float, timeout: float,
) -> Generator[Any, None, services_pb2.Empty]: ) -> "Generator[Any, None, services_pb2.Empty]":
while not shutdown_event.is_set(): while not shutdown_event.is_set():
try: try:
message = transitions_queue.get(block=True, timeout=timeout) message = transitions_queue.get(block=True, timeout=timeout)
@@ -665,7 +677,7 @@ def interactions_stream(
shutdown_event: Any, # Event shutdown_event: Any, # Event
interactions_queue: Queue, interactions_queue: Queue,
timeout: float, timeout: float,
) -> Generator[Any, None, services_pb2.Empty]: ) -> "Generator[Any, None, services_pb2.Empty]":
while not shutdown_event.is_set(): while not shutdown_event.is_set():
try: try:
message = interactions_queue.get(block=True, timeout=timeout) message = interactions_queue.get(block=True, timeout=timeout)
+41 -36
View File
@@ -51,44 +51,47 @@ import time
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from pathlib import Path from pathlib import Path
from pprint import pformat from pprint import pformat
from typing import Any from typing import TYPE_CHECKING, Any
from lerobot.utils.import_utils import require_package from lerobot.utils.import_utils import _grpc_available, require_package
# Fail fast with a friendly error if the optional ``hilserl`` extra is missing. if TYPE_CHECKING or _grpc_available:
require_package("grpcio", extra="hilserl", import_name="grpc") import grpc
import grpc # noqa: E402 from lerobot.transport import services_pb2_grpc
import torch # noqa: E402 else:
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE # noqa: E402 grpc = None
from safetensors.torch import load_file as load_safetensors # noqa: E402 services_pb2_grpc = None
from termcolor import colored # noqa: E402
from torch import nn # noqa: E402
from torch.multiprocessing import Queue # noqa: E402
from torch.optim.optimizer import Optimizer # noqa: E402
from lerobot.cameras import opencv # noqa: F401, E402 import torch
from lerobot.common.train_utils import ( # noqa: E402 from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from safetensors.torch import load_file as load_safetensors
from termcolor import colored
from torch import nn
from torch.multiprocessing import Queue
from torch.optim.optimizer import Optimizer
from lerobot.cameras import opencv # noqa: F401
from lerobot.common.train_utils import (
get_step_checkpoint_dir, get_step_checkpoint_dir,
load_training_state as utils_load_training_state, load_training_state as utils_load_training_state,
save_checkpoint, save_checkpoint,
update_last_checkpoint, update_last_checkpoint,
) )
from lerobot.common.wandb_utils import WandBLogger # noqa: E402 from lerobot.common.wandb_utils import WandBLogger
from lerobot.configs import parser # noqa: E402 from lerobot.configs import parser
from lerobot.datasets import LeRobotDataset, make_dataset # noqa: E402 from lerobot.datasets import LeRobotDataset, make_dataset
from lerobot.policies import make_policy, make_pre_post_processors # noqa: E402 from lerobot.policies import make_policy, make_pre_post_processors
from lerobot.robots import so_follower # noqa: F401, E402 from lerobot.robots import so_follower # noqa: F401
from lerobot.teleoperators import gamepad, so_leader # noqa: F401, E402 from lerobot.teleoperators import gamepad, so_leader # noqa: F401
from lerobot.teleoperators.utils import TeleopEvents # noqa: E402 from lerobot.teleoperators.utils import TeleopEvents
from lerobot.transport import services_pb2_grpc # noqa: E402 from lerobot.transport.utils import (
from lerobot.transport.utils import ( # noqa: E402
MAX_MESSAGE_SIZE, MAX_MESSAGE_SIZE,
bytes_to_python_object, bytes_to_python_object,
bytes_to_transitions, bytes_to_transitions,
state_to_bytes, state_to_bytes,
) )
from lerobot.utils.constants import ( # noqa: E402 from lerobot.utils.constants import (
ACTION, ACTION,
ALGORITHM_DIR, ALGORITHM_DIR,
CHECKPOINTS_DIR, CHECKPOINTS_DIR,
@@ -97,26 +100,28 @@ from lerobot.utils.constants import ( # noqa: E402
TRAINING_STATE_DIR, TRAINING_STATE_DIR,
TRAINING_STEP, TRAINING_STEP,
) )
from lerobot.utils.device_utils import get_safe_torch_device # noqa: E402 from lerobot.utils.device_utils import get_safe_torch_device
from lerobot.utils.io_utils import load_json, write_json # noqa: E402 from lerobot.utils.io_utils import load_json, write_json
from lerobot.utils.process import ProcessSignalHandler # noqa: E402 from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.random_utils import set_seed # noqa: E402 from lerobot.utils.random_utils import set_seed
from lerobot.utils.utils import ( # noqa: E402 from lerobot.utils.utils import (
format_big_number, format_big_number,
init_logging, init_logging,
) )
from .algorithms.base import RLAlgorithm # noqa: E402 from .algorithms.base import RLAlgorithm
from .algorithms.factory import make_algorithm # noqa: E402 from .algorithms.factory import make_algorithm
from .buffer import ReplayBuffer # noqa: E402 from .buffer import ReplayBuffer
from .data_sources import OnlineOfflineMixer # noqa: E402 from .data_sources import OnlineOfflineMixer
from .learner_service import MAX_WORKERS, SHUTDOWN_TIMEOUT, LearnerService # noqa: E402 from .learner_service import MAX_WORKERS, SHUTDOWN_TIMEOUT, LearnerService
from .train_rl import TrainRLServerPipelineConfig # noqa: E402 from .train_rl import TrainRLServerPipelineConfig
from .trainer import RLTrainer # noqa: E402 from .trainer import RLTrainer
@parser.wrap() @parser.wrap()
def train_cli(cfg: TrainRLServerPipelineConfig): def train_cli(cfg: TrainRLServerPipelineConfig):
# Fail fast with a friendly error if the optional ``hilserl`` extra is missing.
require_package("grpcio", extra="hilserl", import_name="grpc")
if not use_threads(cfg): if not use_threads(cfg):
import torch.multiprocessing as mp import torch.multiprocessing as mp
+23 -11
View File
@@ -18,22 +18,32 @@
import logging import logging
import time import time
from multiprocessing import Event, Queue from multiprocessing import Event, Queue
from typing import TYPE_CHECKING
from lerobot.utils.import_utils import require_package from lerobot.utils.import_utils import _grpc_available
# Fail fast with a friendly error if the optional ``hilserl`` extra is missing. from .queue import get_last_item_from_queue
require_package("grpcio", extra="hilserl", import_name="grpc")
from lerobot.transport import services_pb2, services_pb2_grpc # noqa: E402 if TYPE_CHECKING or _grpc_available:
from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks # noqa: E402 import grpc
from .queue import get_last_item_from_queue # noqa: E402 from lerobot.transport import services_pb2, services_pb2_grpc
from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks
_ServicerBase = services_pb2_grpc.LearnerServiceServicer
else:
grpc = None
services_pb2 = None
services_pb2_grpc = None
receive_bytes_in_chunks = None
send_bytes_in_chunks = None
_ServicerBase = object
MAX_WORKERS = 3 # Stream parameters, send transitions and interactions MAX_WORKERS = 3 # Stream parameters, send transitions and interactions
SHUTDOWN_TIMEOUT = 10 SHUTDOWN_TIMEOUT = 10
class LearnerService(services_pb2_grpc.LearnerServiceServicer): class LearnerService(_ServicerBase):
""" """
Implementation of the LearnerService gRPC service Implementation of the LearnerService gRPC service
This service is used to send parameters to the Actor and receive transitions and interactions from the Actor This service is used to send parameters to the Actor and receive transitions and interactions from the Actor
@@ -56,7 +66,9 @@ class LearnerService(services_pb2_grpc.LearnerServiceServicer):
self.interaction_message_queue = interaction_message_queue self.interaction_message_queue = interaction_message_queue
self.queue_get_timeout = queue_get_timeout self.queue_get_timeout = queue_get_timeout
def StreamParameters(self, request, context): # noqa: N802 def StreamParameters( # noqa: N802
self, request: "services_pb2.Empty", context: "grpc.ServicerContext"
):
# TODO: authorize the request # TODO: authorize the request
logging.info("[LEARNER] Received request to stream parameters from the Actor") logging.info("[LEARNER] Received request to stream parameters from the Actor")
@@ -91,7 +103,7 @@ class LearnerService(services_pb2_grpc.LearnerServiceServicer):
logging.info("[LEARNER] Stream parameters finished") logging.info("[LEARNER] Stream parameters finished")
return services_pb2.Empty() return services_pb2.Empty()
def SendTransitions(self, request_iterator, _context): # noqa: N802 def SendTransitions(self, request_iterator, _context: "grpc.ServicerContext"): # noqa: N802
# TODO: authorize the request # TODO: authorize the request
logging.info("[LEARNER] Received request to receive transitions from the Actor") logging.info("[LEARNER] Received request to receive transitions from the Actor")
@@ -105,7 +117,7 @@ class LearnerService(services_pb2_grpc.LearnerServiceServicer):
logging.debug("[LEARNER] Finished receiving transitions") logging.debug("[LEARNER] Finished receiving transitions")
return services_pb2.Empty() return services_pb2.Empty()
def SendInteractions(self, request_iterator, _context): # noqa: N802 def SendInteractions(self, request_iterator, _context: "grpc.ServicerContext"): # noqa: N802
# TODO: authorize the request # TODO: authorize the request
logging.info("[LEARNER] Received request to receive interactions from the Actor") logging.info("[LEARNER] Received request to receive interactions from the Actor")
@@ -119,5 +131,5 @@ class LearnerService(services_pb2_grpc.LearnerServiceServicer):
logging.debug("[LEARNER] Finished receiving interactions") logging.debug("[LEARNER] Finished receiving interactions")
return services_pb2.Empty() return services_pb2.Empty()
def Ready(self, request, context): # noqa: N802 def Ready(self, request: "services_pb2.Empty", context: "grpc.ServicerContext"): # noqa: N802
return services_pb2.Empty() return services_pb2.Empty()