mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 19:49:49 +00:00
refactor(rl): update type hints for learner and actor functions
This commit is contained in:
+37
-22
@@ -46,11 +46,10 @@ For more details on the complete HILSerl training workflow, see:
|
|||||||
https://github.com/michel-aractingi/lerobot-hilserl-guide
|
https://github.com/michel-aractingi/lerobot-hilserl-guide
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
from collections.abc import Generator
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from queue import Empty
|
from queue import Empty
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
@@ -433,10 +432,10 @@ def act_with_policy(
|
|||||||
|
|
||||||
|
|
||||||
def establish_learner_connection(
|
def establish_learner_connection(
|
||||||
stub: services_pb2_grpc.LearnerServiceStub,
|
stub: "services_pb2_grpc.LearnerServiceStub",
|
||||||
shutdown_event: Any, # Event
|
shutdown_event: Any, # Event
|
||||||
attempts: int = 30,
|
attempts: int = 30,
|
||||||
):
|
) -> bool:
|
||||||
"""Establish a connection with the learner.
|
"""Establish a connection with the learner.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -466,12 +465,14 @@ def establish_learner_connection(
|
|||||||
def learner_service_client(
|
def learner_service_client(
|
||||||
host: str = "127.0.0.1",
|
host: str = "127.0.0.1",
|
||||||
port: int = 50051,
|
port: int = 50051,
|
||||||
) -> tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]:
|
) -> "tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]":
|
||||||
"""
|
"""Return a client for the learner service.
|
||||||
Returns a client for the learner service.
|
|
||||||
|
|
||||||
gRPC uses HTTP/2, which is a binary protocol and multiplexes requests over a single connection.
|
gRPC uses HTTP/2, which is a binary protocol and multiplexes requests over a single connection.
|
||||||
So we need to create only one client and reuse it.
|
So we need to create only one client and reuse it.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]: The stub and the channel.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
channel = grpc.insecure_channel(
|
channel = grpc.insecure_channel(
|
||||||
@@ -487,15 +488,17 @@ def receive_policy(
|
|||||||
cfg: TrainRLServerPipelineConfig,
|
cfg: TrainRLServerPipelineConfig,
|
||||||
parameters_queue: Queue,
|
parameters_queue: Queue,
|
||||||
shutdown_event: Any, # Event
|
shutdown_event: Any, # Event
|
||||||
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
|
learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None,
|
||||||
grpc_channel: grpc.Channel | None = None,
|
grpc_channel: "grpc.Channel | None" = None,
|
||||||
):
|
) -> None:
|
||||||
"""Receive parameters from the learner.
|
"""Receive parameters from the learner.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cfg (TrainRLServerPipelineConfig): The configuration for the actor.
|
cfg (TrainRLServerPipelineConfig): The configuration for the actor.
|
||||||
parameters_queue (Queue): The queue to receive the parameters.
|
parameters_queue (Queue): The queue to receive the parameters.
|
||||||
shutdown_event (Event): The event to check if the process should shut down.
|
shutdown_event (Event): The event to check if the process should shut down.
|
||||||
|
learner_client (services_pb2_grpc.LearnerServiceStub | None): Optional pre-created stub.
|
||||||
|
grpc_channel (grpc.Channel | None): Optional pre-created channel.
|
||||||
"""
|
"""
|
||||||
logging.info("[ACTOR] Start receiving parameters from the Learner")
|
logging.info("[ACTOR] Start receiving parameters from the Learner")
|
||||||
if not use_threads(cfg):
|
if not use_threads(cfg):
|
||||||
@@ -539,11 +542,10 @@ def send_transitions(
|
|||||||
cfg: TrainRLServerPipelineConfig,
|
cfg: TrainRLServerPipelineConfig,
|
||||||
transitions_queue: Queue,
|
transitions_queue: Queue,
|
||||||
shutdown_event: Any, # Event
|
shutdown_event: Any, # Event
|
||||||
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
|
learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None,
|
||||||
grpc_channel: grpc.Channel | None = None,
|
grpc_channel: "grpc.Channel | None" = None,
|
||||||
) -> services_pb2.Empty:
|
) -> None:
|
||||||
"""
|
"""Send transitions to the learner.
|
||||||
Sends transitions to the learner.
|
|
||||||
|
|
||||||
This function continuously retrieves messages from the queue and processes:
|
This function continuously retrieves messages from the queue and processes:
|
||||||
|
|
||||||
@@ -551,6 +553,13 @@ def send_transitions(
|
|||||||
- A batch of transitions (observation, action, reward, next observation) is collected.
|
- A batch of transitions (observation, action, reward, next observation) is collected.
|
||||||
- Transitions are moved to the CPU and serialized using PyTorch.
|
- Transitions are moved to the CPU and serialized using PyTorch.
|
||||||
- The serialized data is wrapped in a `services_pb2.Transition` message and sent to the learner.
|
- The serialized data is wrapped in a `services_pb2.Transition` message and sent to the learner.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg (TrainRLServerPipelineConfig): The configuration for the actor.
|
||||||
|
transitions_queue (Queue): The queue to receive the transitions.
|
||||||
|
shutdown_event (Event): The event to check if the process should shut down.
|
||||||
|
learner_client (services_pb2_grpc.LearnerServiceStub | None): Optional pre-created stub.
|
||||||
|
grpc_channel (grpc.Channel | None): Optional pre-created channel.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not use_threads(cfg):
|
if not use_threads(cfg):
|
||||||
@@ -589,17 +598,23 @@ def send_interactions(
|
|||||||
cfg: TrainRLServerPipelineConfig,
|
cfg: TrainRLServerPipelineConfig,
|
||||||
interactions_queue: Queue,
|
interactions_queue: Queue,
|
||||||
shutdown_event: Any, # Event
|
shutdown_event: Any, # Event
|
||||||
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
|
learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None,
|
||||||
grpc_channel: grpc.Channel | None = None,
|
grpc_channel: "grpc.Channel | None" = None,
|
||||||
) -> services_pb2.Empty:
|
) -> None:
|
||||||
"""
|
"""Send interactions to the learner.
|
||||||
Sends interactions to the learner.
|
|
||||||
|
|
||||||
This function continuously retrieves messages from the queue and processes:
|
This function continuously retrieves messages from the queue and processes:
|
||||||
|
|
||||||
- Interaction Messages:
|
- Interaction Messages:
|
||||||
- Contains useful statistics about episodic rewards and policy timings.
|
- Contains useful statistics about episodic rewards and policy timings.
|
||||||
- The message is serialized using `pickle` and sent to the learner.
|
- The message is serialized using `pickle` and sent to the learner.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg (TrainRLServerPipelineConfig): The configuration for the actor.
|
||||||
|
interactions_queue (Queue): The queue to receive the interactions.
|
||||||
|
shutdown_event (Event): The event to check if the process should shut down.
|
||||||
|
learner_client (services_pb2_grpc.LearnerServiceStub | None): Optional pre-created stub.
|
||||||
|
grpc_channel (grpc.Channel | None): Optional pre-created channel.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not use_threads(cfg):
|
if not use_threads(cfg):
|
||||||
@@ -642,7 +657,7 @@ def transitions_stream(
|
|||||||
shutdown_event: Any, # Event
|
shutdown_event: Any, # Event
|
||||||
transitions_queue: Queue,
|
transitions_queue: Queue,
|
||||||
timeout: float,
|
timeout: float,
|
||||||
) -> services_pb2.Empty:
|
) -> "Generator[Any, None, services_pb2.Empty]":
|
||||||
while not shutdown_event.is_set():
|
while not shutdown_event.is_set():
|
||||||
try:
|
try:
|
||||||
message = transitions_queue.get(block=True, timeout=timeout)
|
message = transitions_queue.get(block=True, timeout=timeout)
|
||||||
@@ -661,7 +676,7 @@ def interactions_stream(
|
|||||||
shutdown_event: Any, # Event
|
shutdown_event: Any, # Event
|
||||||
interactions_queue: Queue,
|
interactions_queue: Queue,
|
||||||
timeout: float,
|
timeout: float,
|
||||||
) -> services_pb2.Empty:
|
) -> "Generator[Any, None, services_pb2.Empty]":
|
||||||
while not shutdown_event.is_set():
|
while not shutdown_event.is_set():
|
||||||
try:
|
try:
|
||||||
message = interactions_queue.get(block=True, timeout=timeout)
|
message = interactions_queue.get(block=True, timeout=timeout)
|
||||||
|
|||||||
@@ -44,8 +44,6 @@ For more details on the complete HILSerl training workflow, see:
|
|||||||
https://github.com/michel-aractingi/lerobot-hilserl-guide
|
https://github.com/michel-aractingi/lerobot-hilserl-guide
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
|||||||
Reference in New Issue
Block a user