diff --git a/src/lerobot/rl/actor.py b/src/lerobot/rl/actor.py index 57883fa37..e35b9ade3 100644 --- a/src/lerobot/rl/actor.py +++ b/src/lerobot/rl/actor.py @@ -52,74 +52,63 @@ import time from collections.abc import Generator from functools import lru_cache from queue import Empty -from typing import TYPE_CHECKING, Any +from typing import Any -import torch -from torch import nn -from torch.multiprocessing import Queue +from lerobot.utils.import_utils import require_package -from lerobot.cameras import opencv # noqa: F401 -from lerobot.configs import parser -from lerobot.policies import make_policy, make_pre_post_processors -from lerobot.processor import TransitionKey -from lerobot.robots import so_follower # noqa: F401 -from lerobot.teleoperators import gamepad, so_leader # noqa: F401 -from lerobot.teleoperators.utils import TeleopEvents -from lerobot.utils.device_utils import get_safe_torch_device -from lerobot.utils.import_utils import _grpc_available, require_package -from lerobot.utils.process import ProcessSignalHandler -from lerobot.utils.random_utils import set_seed -from lerobot.utils.robot_utils import precise_sleep -from lerobot.utils.transition import ( +# Fail fast with a friendly error if the optional ``hilserl`` extra is missing. +require_package("grpcio", extra="hilserl", import_name="grpc") + +import grpc # noqa: E402 +import torch # noqa: E402 +from torch import nn # noqa: E402 +from torch.multiprocessing import Queue # noqa: E402 + +from lerobot.cameras import opencv # noqa: F401, E402 +from lerobot.configs import parser # noqa: E402 +from lerobot.policies import make_policy, make_pre_post_processors # noqa: E402 +from lerobot.processor import TransitionKey # noqa: E402 +from lerobot.robots import so_follower # noqa: F401, E402 +from lerobot.teleoperators import gamepad, so_leader # noqa: F401, E402 +from lerobot.teleoperators.utils import TeleopEvents # noqa: E402 +from lerobot.transport import services_pb2, services_pb2_grpc # noqa: E402 +from lerobot.transport.utils import ( # noqa: E402 + bytes_to_state_dict, + grpc_channel_options, + python_object_to_bytes, + receive_bytes_in_chunks, + send_bytes_in_chunks, + transitions_to_bytes, +) +from lerobot.utils.device_utils import get_safe_torch_device # noqa: E402 +from lerobot.utils.process import ProcessSignalHandler # noqa: E402 +from lerobot.utils.random_utils import set_seed # noqa: E402 +from lerobot.utils.robot_utils import precise_sleep # noqa: E402 +from lerobot.utils.transition import ( # noqa: E402 Transition, move_transition_to_device, ) -from lerobot.utils.utils import ( +from lerobot.utils.utils import ( # noqa: E402 TimerManager, init_logging, ) -from .algorithms.base import RLAlgorithm -from .algorithms.factory import make_algorithm - -if TYPE_CHECKING or _grpc_available: - import grpc - - from lerobot.transport import services_pb2, services_pb2_grpc - from lerobot.transport.utils import ( - bytes_to_state_dict, - grpc_channel_options, - python_object_to_bytes, - receive_bytes_in_chunks, - send_bytes_in_chunks, - transitions_to_bytes, - ) -else: - grpc = None - services_pb2 = None - services_pb2_grpc = None - bytes_to_state_dict = None - grpc_channel_options = None - python_object_to_bytes = None - receive_bytes_in_chunks = None - send_bytes_in_chunks = None - transitions_to_bytes = None - -from .gym_manipulator import ( +from .algorithms.base import RLAlgorithm # noqa: E402 +from .algorithms.factory import make_algorithm # noqa: E402 +from .gym_manipulator import ( # noqa: E402 make_processors, make_robot_env, reset_and_build_transition, step_env_and_process_transition, ) -from .queue import get_last_item_from_queue -from .train_rl import TrainRLServerPipelineConfig +from .queue import get_last_item_from_queue # noqa: E402 +from .train_rl import TrainRLServerPipelineConfig # noqa: E402 # Main entry point @parser.wrap() def actor_cli(cfg: TrainRLServerPipelineConfig): - require_package("grpcio", extra="hilserl", import_name="grpc") cfg.validate() display_pid = False if not use_threads(cfg): @@ -432,7 +421,7 @@ def act_with_policy( def establish_learner_connection( - stub: "services_pb2_grpc.LearnerServiceStub", + stub: services_pb2_grpc.LearnerServiceStub, shutdown_event: Any, # Event attempts: int = 30, ) -> bool: @@ -465,7 +454,7 @@ def establish_learner_connection( def learner_service_client( host: str = "127.0.0.1", port: int = 50051, -) -> "tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]": +) -> tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]: """Return a client for the learner service. GRPC uses HTTP/2, which is a binary protocol and multiplexes requests over a single connection. @@ -488,8 +477,8 @@ def receive_policy( cfg: TrainRLServerPipelineConfig, parameters_queue: Queue, shutdown_event: Any, # Event - learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None, - grpc_channel: "grpc.Channel | None" = None, + learner_client: services_pb2_grpc.LearnerServiceStub | None = None, + grpc_channel: grpc.Channel | None = None, ) -> None: """Receive parameters from the learner. @@ -542,8 +531,8 @@ def send_transitions( cfg: TrainRLServerPipelineConfig, transitions_queue: Queue, shutdown_event: Any, # Event - learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None, - grpc_channel: "grpc.Channel | None" = None, + learner_client: services_pb2_grpc.LearnerServiceStub | None = None, + grpc_channel: grpc.Channel | None = None, ) -> None: """Send transitions to the learner. @@ -598,8 +587,8 @@ def send_interactions( cfg: TrainRLServerPipelineConfig, interactions_queue: Queue, shutdown_event: Any, # Event - learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None, - grpc_channel: "grpc.Channel | None" = None, + learner_client: services_pb2_grpc.LearnerServiceStub | None = None, + grpc_channel: grpc.Channel | None = None, ) -> None: """Send interactions to the learner. @@ -657,7 +646,7 @@ def transitions_stream( shutdown_event: Any, # Event transitions_queue: Queue, timeout: float, -) -> "Generator[Any, None, services_pb2.Empty]": +) -> Generator[Any, None, services_pb2.Empty]: while not shutdown_event.is_set(): try: message = transitions_queue.get(block=True, timeout=timeout) @@ -676,7 +665,7 @@ def interactions_stream( shutdown_event: Any, # Event interactions_queue: Queue, timeout: float, -) -> "Generator[Any, None, services_pb2.Empty]": +) -> Generator[Any, None, services_pb2.Empty]: while not shutdown_event.is_set(): try: message = interactions_queue.get(block=True, timeout=timeout) diff --git a/src/lerobot/rl/learner.py b/src/lerobot/rl/learner.py index 6b3b620a7..0f89dc439 100644 --- a/src/lerobot/rl/learner.py +++ b/src/lerobot/rl/learner.py @@ -51,31 +51,44 @@ import time from concurrent.futures import ThreadPoolExecutor from pathlib import Path from pprint import pformat -from typing import TYPE_CHECKING, Any +from typing import Any -import torch -from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE -from safetensors.torch import load_file as load_safetensors -from termcolor import colored -from torch import nn -from torch.multiprocessing import Queue -from torch.optim.optimizer import Optimizer +from lerobot.utils.import_utils import require_package -from lerobot.cameras import opencv # noqa: F401 -from lerobot.common.train_utils import ( +# Fail fast with a friendly error if the optional ``hilserl`` extra is missing. +require_package("grpcio", extra="hilserl", import_name="grpc") + +import grpc # noqa: E402 +import torch # noqa: E402 +from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE # noqa: E402 +from safetensors.torch import load_file as load_safetensors # noqa: E402 +from termcolor import colored # noqa: E402 +from torch import nn # noqa: E402 +from torch.multiprocessing import Queue # noqa: E402 +from torch.optim.optimizer import Optimizer # noqa: E402 + +from lerobot.cameras import opencv # noqa: F401, E402 +from lerobot.common.train_utils import ( # noqa: E402 get_step_checkpoint_dir, load_training_state as utils_load_training_state, save_checkpoint, update_last_checkpoint, ) -from lerobot.common.wandb_utils import WandBLogger -from lerobot.configs import parser -from lerobot.datasets import LeRobotDataset, make_dataset -from lerobot.policies import make_policy, make_pre_post_processors -from lerobot.robots import so_follower # noqa: F401 -from lerobot.teleoperators import gamepad, so_leader # noqa: F401 -from lerobot.teleoperators.utils import TeleopEvents -from lerobot.utils.constants import ( +from lerobot.common.wandb_utils import WandBLogger # noqa: E402 +from lerobot.configs import parser # noqa: E402 +from lerobot.datasets import LeRobotDataset, make_dataset # noqa: E402 +from lerobot.policies import make_policy, make_pre_post_processors # noqa: E402 +from lerobot.robots import so_follower # noqa: F401, E402 +from lerobot.teleoperators import gamepad, so_leader # noqa: F401, E402 +from lerobot.teleoperators.utils import TeleopEvents # noqa: E402 +from lerobot.transport import services_pb2_grpc # noqa: E402 +from lerobot.transport.utils import ( # noqa: E402 + MAX_MESSAGE_SIZE, + bytes_to_python_object, + bytes_to_transitions, + state_to_bytes, +) +from lerobot.utils.constants import ( # noqa: E402 ACTION, ALGORITHM_DIR, CHECKPOINTS_DIR, @@ -84,46 +97,26 @@ from lerobot.utils.constants import ( TRAINING_STATE_DIR, TRAINING_STEP, ) -from lerobot.utils.device_utils import get_safe_torch_device -from lerobot.utils.import_utils import _grpc_available, require_package -from lerobot.utils.io_utils import load_json, write_json -from lerobot.utils.process import ProcessSignalHandler -from lerobot.utils.random_utils import set_seed -from lerobot.utils.utils import ( +from lerobot.utils.device_utils import get_safe_torch_device # noqa: E402 +from lerobot.utils.io_utils import load_json, write_json # noqa: E402 +from lerobot.utils.process import ProcessSignalHandler # noqa: E402 +from lerobot.utils.random_utils import set_seed # noqa: E402 +from lerobot.utils.utils import ( # noqa: E402 format_big_number, init_logging, ) -if TYPE_CHECKING or _grpc_available: - import grpc - - from lerobot.transport import services_pb2_grpc - from lerobot.transport.utils import ( - MAX_MESSAGE_SIZE, - bytes_to_python_object, - bytes_to_transitions, - state_to_bytes, - ) -else: - grpc = None - services_pb2_grpc = None - MAX_MESSAGE_SIZE = None - bytes_to_python_object = None - bytes_to_transitions = None - state_to_bytes = None - -from .algorithms.base import RLAlgorithm -from .algorithms.factory import make_algorithm -from .buffer import ReplayBuffer -from .data_sources import OnlineOfflineMixer -from .learner_service import MAX_WORKERS, SHUTDOWN_TIMEOUT, LearnerService -from .train_rl import TrainRLServerPipelineConfig -from .trainer import RLTrainer +from .algorithms.base import RLAlgorithm # noqa: E402 +from .algorithms.factory import make_algorithm # noqa: E402 +from .buffer import ReplayBuffer # noqa: E402 +from .data_sources import OnlineOfflineMixer # noqa: E402 +from .learner_service import MAX_WORKERS, SHUTDOWN_TIMEOUT, LearnerService # noqa: E402 +from .train_rl import TrainRLServerPipelineConfig # noqa: E402 +from .trainer import RLTrainer # noqa: E402 @parser.wrap() def train_cli(cfg: TrainRLServerPipelineConfig): - require_package("grpcio", extra="hilserl", import_name="grpc") if not use_threads(cfg): import torch.multiprocessing as mp diff --git a/src/lerobot/rl/learner_service.py b/src/lerobot/rl/learner_service.py index 65af94038..a95eb37a6 100644 --- a/src/lerobot/rl/learner_service.py +++ b/src/lerobot/rl/learner_service.py @@ -18,29 +18,22 @@ import logging import time from multiprocessing import Event, Queue -from typing import TYPE_CHECKING -from lerobot.utils.import_utils import _grpc_available, require_package +from lerobot.utils.import_utils import require_package -if TYPE_CHECKING or _grpc_available: - from lerobot.transport import services_pb2, services_pb2_grpc - from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks +# Fail fast with a friendly error if the optional ``hilserl`` extra is missing. +require_package("grpcio", extra="hilserl", import_name="grpc") - _ServicerBase = services_pb2_grpc.LearnerServiceServicer -else: - services_pb2 = None - services_pb2_grpc = None - receive_bytes_in_chunks = None - send_bytes_in_chunks = None - _ServicerBase = object +from lerobot.transport import services_pb2, services_pb2_grpc # noqa: E402 +from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks # noqa: E402 -from .queue import get_last_item_from_queue +from .queue import get_last_item_from_queue # noqa: E402 MAX_WORKERS = 3 # Stream parameters, send transitions and interactions SHUTDOWN_TIMEOUT = 10 -class LearnerService(_ServicerBase): +class LearnerService(services_pb2_grpc.LearnerServiceServicer): """ Implementation of the LearnerService gRPC service This service is used to send parameters to the Actor and receive transitions and interactions from the Actor @@ -56,7 +49,6 @@ class LearnerService(_ServicerBase): interaction_message_queue: Queue, queue_get_timeout: float = 0.001, ): - require_package("grpcio", extra="hilserl", import_name="grpc") self.shutdown_event = shutdown_event self.parameters_queue = parameters_queue self.seconds_between_pushes = seconds_between_pushes