refactor(rl): update type hints for learner and actor functions

This commit is contained in:
Khalil Meftah
2026-05-09 16:46:14 +02:00
parent ef927ac830
commit a5222d3f1d
2 changed files with 37 additions and 24 deletions
+37 -22
View File
@@ -46,11 +46,10 @@ For more details on the complete HILSerl training workflow, see:
https://github.com/michel-aractingi/lerobot-hilserl-guide
"""
from __future__ import annotations
import logging
import os
import time
from collections.abc import Generator
from functools import lru_cache
from queue import Empty
from typing import TYPE_CHECKING, Any
@@ -433,10 +432,10 @@ def act_with_policy(
def establish_learner_connection(
stub: services_pb2_grpc.LearnerServiceStub,
stub: "services_pb2_grpc.LearnerServiceStub",
shutdown_event: Any, # Event
attempts: int = 30,
):
) -> bool:
"""Establish a connection with the learner.
Args:
@@ -466,12 +465,14 @@ def establish_learner_connection(
def learner_service_client(
host: str = "127.0.0.1",
port: int = 50051,
) -> tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]:
"""
Returns a client for the learner service.
) -> "tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]":
"""Return a client for the learner service.
gRPC uses HTTP/2, which is a binary protocol and multiplexes requests over a single connection.
So we need to create only one client and reuse it.
Returns:
tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]: The stub and the channel.
"""
channel = grpc.insecure_channel(
@@ -487,15 +488,17 @@ def receive_policy(
cfg: TrainRLServerPipelineConfig,
parameters_queue: Queue,
shutdown_event: Any, # Event
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
grpc_channel: grpc.Channel | None = None,
):
learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None,
grpc_channel: "grpc.Channel | None" = None,
) -> None:
"""Receive parameters from the learner.
Args:
cfg (TrainRLServerPipelineConfig): The configuration for the actor.
parameters_queue (Queue): The queue to receive the parameters.
shutdown_event (Event): The event to check if the process should shut down.
learner_client (services_pb2_grpc.LearnerServiceStub | None): Optional pre-created stub.
grpc_channel (grpc.Channel | None): Optional pre-created channel.
"""
logging.info("[ACTOR] Start receiving parameters from the Learner")
if not use_threads(cfg):
@@ -539,11 +542,10 @@ def send_transitions(
cfg: TrainRLServerPipelineConfig,
transitions_queue: Queue,
shutdown_event: Any, # Event
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
grpc_channel: grpc.Channel | None = None,
) -> services_pb2.Empty:
"""
Sends transitions to the learner.
learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None,
grpc_channel: "grpc.Channel | None" = None,
) -> None:
"""Send transitions to the learner.
This function continuously retrieves messages from the queue and processes:
@@ -551,6 +553,13 @@ def send_transitions(
- A batch of transitions (observation, action, reward, next observation) is collected.
- Transitions are moved to the CPU and serialized using PyTorch.
- The serialized data is wrapped in a `services_pb2.Transition` message and sent to the learner.
Args:
cfg (TrainRLServerPipelineConfig): The configuration for the actor.
transitions_queue (Queue): The queue to receive the transitions.
shutdown_event (Event): The event to check if the process should shut down.
learner_client (services_pb2_grpc.LearnerServiceStub | None): Optional pre-created stub.
grpc_channel (grpc.Channel | None): Optional pre-created channel.
"""
if not use_threads(cfg):
@@ -589,17 +598,23 @@ def send_interactions(
cfg: TrainRLServerPipelineConfig,
interactions_queue: Queue,
shutdown_event: Any, # Event
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
grpc_channel: grpc.Channel | None = None,
) -> services_pb2.Empty:
"""
Sends interactions to the learner.
learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None,
grpc_channel: "grpc.Channel | None" = None,
) -> None:
"""Send interactions to the learner.
This function continuously retrieves messages from the queue and processes:
- Interaction Messages:
- Contains useful statistics about episodic rewards and policy timings.
- The message is serialized using `pickle` and sent to the learner.
Args:
cfg (TrainRLServerPipelineConfig): The configuration for the actor.
interactions_queue (Queue): The queue to receive the interactions.
shutdown_event (Event): The event to check if the process should shut down.
learner_client (services_pb2_grpc.LearnerServiceStub | None): Optional pre-created stub.
grpc_channel (grpc.Channel | None): Optional pre-created channel.
"""
if not use_threads(cfg):
@@ -642,7 +657,7 @@ def transitions_stream(
shutdown_event: Any, # Event
transitions_queue: Queue,
timeout: float,
) -> services_pb2.Empty:
) -> "Generator[Any, None, services_pb2.Empty]":
while not shutdown_event.is_set():
try:
message = transitions_queue.get(block=True, timeout=timeout)
@@ -661,7 +676,7 @@ def interactions_stream(
shutdown_event: Any, # Event
interactions_queue: Queue,
timeout: float,
) -> services_pb2.Empty:
) -> "Generator[Any, None, services_pb2.Empty]":
while not shutdown_event.is_set():
try:
message = interactions_queue.get(block=True, timeout=timeout)
-2
View File
@@ -44,8 +44,6 @@ For more details on the complete HILSerl training workflow, see:
https://github.com/michel-aractingi/lerobot-hilserl-guide
"""
from __future__ import annotations
import logging
import os
import shutil