From a5222d3f1db387f269614efc06c6d87409ca1827 Mon Sep 17 00:00:00 2001 From: Khalil Meftah Date: Sat, 9 May 2026 16:46:14 +0200 Subject: [PATCH] refactor(rl): update type hints for learner and actor functions --- src/lerobot/rl/actor.py | 59 ++++++++++++++++++++++++--------------- src/lerobot/rl/learner.py | 2 -- 2 files changed, 37 insertions(+), 24 deletions(-) diff --git a/src/lerobot/rl/actor.py b/src/lerobot/rl/actor.py index c553abd12..57883fa37 100644 --- a/src/lerobot/rl/actor.py +++ b/src/lerobot/rl/actor.py @@ -46,11 +46,10 @@ For more details on the complete HILSerl training workflow, see: https://github.com/michel-aractingi/lerobot-hilserl-guide """ -from __future__ import annotations - import logging import os import time +from collections.abc import Generator from functools import lru_cache from queue import Empty from typing import TYPE_CHECKING, Any @@ -433,10 +432,10 @@ def act_with_policy( def establish_learner_connection( - stub: services_pb2_grpc.LearnerServiceStub, + stub: "services_pb2_grpc.LearnerServiceStub", shutdown_event: Any, # Event attempts: int = 30, -): +) -> bool: """Establish a connection with the learner. Args: @@ -466,12 +465,14 @@ def establish_learner_connection( def learner_service_client( host: str = "127.0.0.1", port: int = 50051, -) -> tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]: - """ - Returns a client for the learner service. +) -> "tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]": + """Return a client for the learner service. GRPC uses HTTP/2, which is a binary protocol and multiplexes requests over a single connection. So we need to create only one client and reuse it. + + Returns: + tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]: The stub and the channel. """ channel = grpc.insecure_channel( @@ -487,15 +488,17 @@ def receive_policy( cfg: TrainRLServerPipelineConfig, parameters_queue: Queue, shutdown_event: Any, # Event - learner_client: services_pb2_grpc.LearnerServiceStub | None = None, - grpc_channel: grpc.Channel | None = None, -): + learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None, + grpc_channel: "grpc.Channel | None" = None, +) -> None: """Receive parameters from the learner. Args: cfg (TrainRLServerPipelineConfig): The configuration for the actor. parameters_queue (Queue): The queue to receive the parameters. shutdown_event (Event): The event to check if the process should shutdown. + learner_client (services_pb2_grpc.LearnerServiceStub | None): Optional pre-created stub. + grpc_channel (grpc.Channel | None): Optional pre-created channel. """ logging.info("[ACTOR] Start receiving parameters from the Learner") if not use_threads(cfg): @@ -539,11 +542,10 @@ def send_transitions( cfg: TrainRLServerPipelineConfig, transitions_queue: Queue, shutdown_event: Any, # Event - learner_client: services_pb2_grpc.LearnerServiceStub | None = None, - grpc_channel: grpc.Channel | None = None, -) -> services_pb2.Empty: - """ - Sends transitions to the learner. + learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None, + grpc_channel: "grpc.Channel | None" = None, +) -> None: + """Send transitions to the learner. This function continuously retrieves messages from the queue and processes: @@ -551,6 +553,13 @@ def send_transitions( - A batch of transitions (observation, action, reward, next observation) is collected. - Transitions are moved to the CPU and serialized using PyTorch. - The serialized data is wrapped in a `services_pb2.Transition` message and sent to the learner. + + Args: + cfg (TrainRLServerPipelineConfig): The configuration for the actor. + transitions_queue (Queue): The queue to receive the transitions. + shutdown_event (Event): The event to check if the process should shutdown. + learner_client (services_pb2_grpc.LearnerServiceStub | None): Optional pre-created stub. + grpc_channel (grpc.Channel | None): Optional pre-created channel. """ if not use_threads(cfg): @@ -589,17 +598,23 @@ def send_interactions( cfg: TrainRLServerPipelineConfig, interactions_queue: Queue, shutdown_event: Any, # Event - learner_client: services_pb2_grpc.LearnerServiceStub | None = None, - grpc_channel: grpc.Channel | None = None, -) -> services_pb2.Empty: - """ - Sends interactions to the learner. + learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None, + grpc_channel: "grpc.Channel | None" = None, +) -> None: + """Send interactions to the learner. This function continuously retrieves messages from the queue and processes: - Interaction Messages: - Contains useful statistics about episodic rewards and policy timings. - The message is serialized using `pickle` and sent to the learner. + + Args: + cfg (TrainRLServerPipelineConfig): The configuration for the actor. + interactions_queue (Queue): The queue to receive the interactions. + shutdown_event (Event): The event to check if the process should shutdown. + learner_client (services_pb2_grpc.LearnerServiceStub | None): Optional pre-created stub. + grpc_channel (grpc.Channel | None): Optional pre-created channel. """ if not use_threads(cfg): @@ -642,7 +657,7 @@ def transitions_stream( shutdown_event: Any, # Event transitions_queue: Queue, timeout: float, -) -> services_pb2.Empty: +) -> "Generator[Any, None, services_pb2.Empty]": while not shutdown_event.is_set(): try: message = transitions_queue.get(block=True, timeout=timeout) @@ -661,7 +676,7 @@ def interactions_stream( shutdown_event: Any, # Event interactions_queue: Queue, timeout: float, -) -> services_pb2.Empty: +) -> "Generator[Any, None, services_pb2.Empty]": while not shutdown_event.is_set(): try: message = interactions_queue.get(block=True, timeout=timeout) diff --git a/src/lerobot/rl/learner.py b/src/lerobot/rl/learner.py index f41d9d602..6b3b620a7 100644 --- a/src/lerobot/rl/learner.py +++ b/src/lerobot/rl/learner.py @@ -44,8 +44,6 @@ For more details on the complete HILSerl training workflow, see: https://github.com/michel-aractingi/lerobot-hilserl-guide """ -from __future__ import annotations - import logging import os import shutil