refactor(rl): update type hints for learner and actor functions

This commit is contained in:
Khalil Meftah
2026-05-09 16:46:14 +02:00
parent ef927ac830
commit a5222d3f1d
2 changed files with 37 additions and 24 deletions
+37 -22
View File
@@ -46,11 +46,10 @@ For more details on the complete HILSerl training workflow, see:
https://github.com/michel-aractingi/lerobot-hilserl-guide https://github.com/michel-aractingi/lerobot-hilserl-guide
""" """
from __future__ import annotations
import logging import logging
import os import os
import time import time
from collections.abc import Generator
from functools import lru_cache from functools import lru_cache
from queue import Empty from queue import Empty
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
@@ -433,10 +432,10 @@ def act_with_policy(
def establish_learner_connection( def establish_learner_connection(
stub: services_pb2_grpc.LearnerServiceStub, stub: "services_pb2_grpc.LearnerServiceStub",
shutdown_event: Any, # Event shutdown_event: Any, # Event
attempts: int = 30, attempts: int = 30,
): ) -> bool:
"""Establish a connection with the learner. """Establish a connection with the learner.
Args: Args:
@@ -466,12 +465,14 @@ def establish_learner_connection(
def learner_service_client( def learner_service_client(
host: str = "127.0.0.1", host: str = "127.0.0.1",
port: int = 50051, port: int = 50051,
) -> tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]: ) -> "tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]":
""" """Return a client for the learner service.
Returns a client for the learner service.
GRPC uses HTTP/2, which is a binary protocol and multiplexes requests over a single connection. GRPC uses HTTP/2, which is a binary protocol and multiplexes requests over a single connection.
So we need to create only one client and reuse it. So we need to create only one client and reuse it.
Returns:
tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]: The stub and the channel.
""" """
channel = grpc.insecure_channel( channel = grpc.insecure_channel(
@@ -487,15 +488,17 @@ def receive_policy(
cfg: TrainRLServerPipelineConfig, cfg: TrainRLServerPipelineConfig,
parameters_queue: Queue, parameters_queue: Queue,
shutdown_event: Any, # Event shutdown_event: Any, # Event
learner_client: services_pb2_grpc.LearnerServiceStub | None = None, learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None,
grpc_channel: grpc.Channel | None = None, grpc_channel: "grpc.Channel | None" = None,
): ) -> None:
"""Receive parameters from the learner. """Receive parameters from the learner.
Args: Args:
cfg (TrainRLServerPipelineConfig): The configuration for the actor. cfg (TrainRLServerPipelineConfig): The configuration for the actor.
parameters_queue (Queue): The queue to receive the parameters. parameters_queue (Queue): The queue to receive the parameters.
shutdown_event (Event): The event to check if the process should shutdown. shutdown_event (Event): The event to check if the process should shutdown.
learner_client (services_pb2_grpc.LearnerServiceStub | None): Optional pre-created stub.
grpc_channel (grpc.Channel | None): Optional pre-created channel.
""" """
logging.info("[ACTOR] Start receiving parameters from the Learner") logging.info("[ACTOR] Start receiving parameters from the Learner")
if not use_threads(cfg): if not use_threads(cfg):
@@ -539,11 +542,10 @@ def send_transitions(
cfg: TrainRLServerPipelineConfig, cfg: TrainRLServerPipelineConfig,
transitions_queue: Queue, transitions_queue: Queue,
shutdown_event: Any, # Event shutdown_event: Any, # Event
learner_client: services_pb2_grpc.LearnerServiceStub | None = None, learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None,
grpc_channel: grpc.Channel | None = None, grpc_channel: "grpc.Channel | None" = None,
) -> services_pb2.Empty: ) -> None:
""" """Send transitions to the learner.
Sends transitions to the learner.
This function continuously retrieves messages from the queue and processes: This function continuously retrieves messages from the queue and processes:
@@ -551,6 +553,13 @@ def send_transitions(
- A batch of transitions (observation, action, reward, next observation) is collected. - A batch of transitions (observation, action, reward, next observation) is collected.
- Transitions are moved to the CPU and serialized using PyTorch. - Transitions are moved to the CPU and serialized using PyTorch.
- The serialized data is wrapped in a `services_pb2.Transition` message and sent to the learner. - The serialized data is wrapped in a `services_pb2.Transition` message and sent to the learner.
Args:
cfg (TrainRLServerPipelineConfig): The configuration for the actor.
transitions_queue (Queue): The queue to receive the transitions.
shutdown_event (Event): The event to check if the process should shutdown.
learner_client (services_pb2_grpc.LearnerServiceStub | None): Optional pre-created stub.
grpc_channel (grpc.Channel | None): Optional pre-created channel.
""" """
if not use_threads(cfg): if not use_threads(cfg):
@@ -589,17 +598,23 @@ def send_interactions(
cfg: TrainRLServerPipelineConfig, cfg: TrainRLServerPipelineConfig,
interactions_queue: Queue, interactions_queue: Queue,
shutdown_event: Any, # Event shutdown_event: Any, # Event
learner_client: services_pb2_grpc.LearnerServiceStub | None = None, learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None,
grpc_channel: grpc.Channel | None = None, grpc_channel: "grpc.Channel | None" = None,
) -> services_pb2.Empty: ) -> None:
""" """Send interactions to the learner.
Sends interactions to the learner.
This function continuously retrieves messages from the queue and processes: This function continuously retrieves messages from the queue and processes:
- Interaction Messages: - Interaction Messages:
- Contains useful statistics about episodic rewards and policy timings. - Contains useful statistics about episodic rewards and policy timings.
- The message is serialized using `pickle` and sent to the learner. - The message is serialized using `pickle` and sent to the learner.
Args:
cfg (TrainRLServerPipelineConfig): The configuration for the actor.
interactions_queue (Queue): The queue to receive the interactions.
shutdown_event (Event): The event to check if the process should shutdown.
learner_client (services_pb2_grpc.LearnerServiceStub | None): Optional pre-created stub.
grpc_channel (grpc.Channel | None): Optional pre-created channel.
""" """
if not use_threads(cfg): if not use_threads(cfg):
@@ -642,7 +657,7 @@ def transitions_stream(
shutdown_event: Any, # Event shutdown_event: Any, # Event
transitions_queue: Queue, transitions_queue: Queue,
timeout: float, timeout: float,
) -> services_pb2.Empty: ) -> "Generator[Any, None, services_pb2.Empty]":
while not shutdown_event.is_set(): while not shutdown_event.is_set():
try: try:
message = transitions_queue.get(block=True, timeout=timeout) message = transitions_queue.get(block=True, timeout=timeout)
@@ -661,7 +676,7 @@ def interactions_stream(
shutdown_event: Any, # Event shutdown_event: Any, # Event
interactions_queue: Queue, interactions_queue: Queue,
timeout: float, timeout: float,
) -> services_pb2.Empty: ) -> "Generator[Any, None, services_pb2.Empty]":
while not shutdown_event.is_set(): while not shutdown_event.is_set():
try: try:
message = interactions_queue.get(block=True, timeout=timeout) message = interactions_queue.get(block=True, timeout=timeout)
-2
View File
@@ -44,8 +44,6 @@ For more details on the complete HILSerl training workflow, see:
https://github.com/michel-aractingi/lerobot-hilserl-guide https://github.com/michel-aractingi/lerobot-hilserl-guide
""" """
from __future__ import annotations
import logging import logging
import os import os
import shutil import shutil