mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 16:49:55 +00:00
chore(rl): manage import pattern in actor (#3564)
* chore(rl): manage import pattern in actor * chore(rl): optional grpc imports in learner; quote grpc ServicerContext types --------- Co-authored-by: Khalil Meftah <khalil.meftah@huggingface.co>
This commit is contained in:
+1
-1
@@ -195,7 +195,7 @@ groot = [
|
||||
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
xvla = ["lerobot[transformers-dep]"]
|
||||
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
|
||||
# Features
|
||||
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
|
||||
|
||||
+57
-45
@@ -52,63 +52,75 @@ import time
|
||||
from collections.abc import Generator
|
||||
from functools import lru_cache
|
||||
from queue import Empty
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from lerobot.utils.import_utils import require_package
|
||||
from lerobot.utils.import_utils import _grpc_available, require_package
|
||||
|
||||
# Fail fast with a friendly error if the optional ``hilserl`` extra is missing.
|
||||
require_package("grpcio", extra="hilserl", import_name="grpc")
|
||||
if TYPE_CHECKING or _grpc_available:
|
||||
import grpc
|
||||
|
||||
import grpc # noqa: E402
|
||||
import torch # noqa: E402
|
||||
from torch import nn # noqa: E402
|
||||
from torch.multiprocessing import Queue # noqa: E402
|
||||
from lerobot.transport import services_pb2, services_pb2_grpc
|
||||
from lerobot.transport.utils import (
|
||||
bytes_to_state_dict,
|
||||
grpc_channel_options,
|
||||
python_object_to_bytes,
|
||||
receive_bytes_in_chunks,
|
||||
send_bytes_in_chunks,
|
||||
transitions_to_bytes,
|
||||
)
|
||||
else:
|
||||
grpc = None
|
||||
services_pb2 = None
|
||||
services_pb2_grpc = None
|
||||
bytes_to_state_dict = None
|
||||
grpc_channel_options = None
|
||||
python_object_to_bytes = None
|
||||
receive_bytes_in_chunks = None
|
||||
send_bytes_in_chunks = None
|
||||
transitions_to_bytes = None
|
||||
|
||||
from lerobot.cameras import opencv # noqa: F401, E402
|
||||
from lerobot.configs import parser # noqa: E402
|
||||
from lerobot.policies import make_policy, make_pre_post_processors # noqa: E402
|
||||
from lerobot.processor import TransitionKey # noqa: E402
|
||||
from lerobot.robots import so_follower # noqa: F401, E402
|
||||
from lerobot.teleoperators import gamepad, so_leader # noqa: F401, E402
|
||||
from lerobot.teleoperators.utils import TeleopEvents # noqa: E402
|
||||
from lerobot.transport import services_pb2, services_pb2_grpc # noqa: E402
|
||||
from lerobot.transport.utils import ( # noqa: E402
|
||||
bytes_to_state_dict,
|
||||
grpc_channel_options,
|
||||
python_object_to_bytes,
|
||||
receive_bytes_in_chunks,
|
||||
send_bytes_in_chunks,
|
||||
transitions_to_bytes,
|
||||
)
|
||||
from lerobot.utils.device_utils import get_safe_torch_device # noqa: E402
|
||||
from lerobot.utils.process import ProcessSignalHandler # noqa: E402
|
||||
from lerobot.utils.random_utils import set_seed # noqa: E402
|
||||
from lerobot.utils.robot_utils import precise_sleep # noqa: E402
|
||||
from lerobot.utils.transition import ( # noqa: E402
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.multiprocessing import Queue
|
||||
|
||||
from lerobot.cameras import opencv # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.policies import make_policy, make_pre_post_processors
|
||||
from lerobot.processor import TransitionKey
|
||||
from lerobot.robots import so_follower # noqa: F401
|
||||
from lerobot.teleoperators import gamepad, so_leader # noqa: F401
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
from lerobot.utils.device_utils import get_safe_torch_device
|
||||
from lerobot.utils.process import ProcessSignalHandler
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.transition import (
|
||||
Transition,
|
||||
move_transition_to_device,
|
||||
)
|
||||
from lerobot.utils.utils import ( # noqa: E402
|
||||
from lerobot.utils.utils import (
|
||||
TimerManager,
|
||||
init_logging,
|
||||
)
|
||||
|
||||
from .algorithms.base import RLAlgorithm # noqa: E402
|
||||
from .algorithms.factory import make_algorithm # noqa: E402
|
||||
from .gym_manipulator import ( # noqa: E402
|
||||
from .algorithms.base import RLAlgorithm
|
||||
from .algorithms.factory import make_algorithm
|
||||
from .gym_manipulator import (
|
||||
make_processors,
|
||||
make_robot_env,
|
||||
reset_and_build_transition,
|
||||
step_env_and_process_transition,
|
||||
)
|
||||
from .queue import get_last_item_from_queue # noqa: E402
|
||||
from .train_rl import TrainRLServerPipelineConfig # noqa: E402
|
||||
from .queue import get_last_item_from_queue
|
||||
from .train_rl import TrainRLServerPipelineConfig
|
||||
|
||||
# Main entry point
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def actor_cli(cfg: TrainRLServerPipelineConfig):
|
||||
# Fail fast with a friendly error if the optional ``hilserl`` extra is missing.
|
||||
require_package("grpcio", extra="hilserl", import_name="grpc")
|
||||
cfg.validate()
|
||||
display_pid = False
|
||||
if not use_threads(cfg):
|
||||
@@ -421,7 +433,7 @@ def act_with_policy(
|
||||
|
||||
|
||||
def establish_learner_connection(
|
||||
stub: services_pb2_grpc.LearnerServiceStub,
|
||||
stub: "services_pb2_grpc.LearnerServiceStub",
|
||||
shutdown_event: Any, # Event
|
||||
attempts: int = 30,
|
||||
) -> bool:
|
||||
@@ -454,7 +466,7 @@ def establish_learner_connection(
|
||||
def learner_service_client(
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 50051,
|
||||
) -> tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]:
|
||||
) -> "tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]":
|
||||
"""Return a client for the learner service.
|
||||
|
||||
GRPC uses HTTP/2, which is a binary protocol and multiplexes requests over a single connection.
|
||||
@@ -477,8 +489,8 @@ def receive_policy(
|
||||
cfg: TrainRLServerPipelineConfig,
|
||||
parameters_queue: Queue,
|
||||
shutdown_event: Any, # Event
|
||||
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
|
||||
grpc_channel: grpc.Channel | None = None,
|
||||
learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None,
|
||||
grpc_channel: "grpc.Channel | None" = None,
|
||||
) -> None:
|
||||
"""Receive parameters from the learner.
|
||||
|
||||
@@ -531,8 +543,8 @@ def send_transitions(
|
||||
cfg: TrainRLServerPipelineConfig,
|
||||
transitions_queue: Queue,
|
||||
shutdown_event: Any, # Event
|
||||
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
|
||||
grpc_channel: grpc.Channel | None = None,
|
||||
learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None,
|
||||
grpc_channel: "grpc.Channel | None" = None,
|
||||
) -> None:
|
||||
"""Send transitions to the learner.
|
||||
|
||||
@@ -587,8 +599,8 @@ def send_interactions(
|
||||
cfg: TrainRLServerPipelineConfig,
|
||||
interactions_queue: Queue,
|
||||
shutdown_event: Any, # Event
|
||||
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
|
||||
grpc_channel: grpc.Channel | None = None,
|
||||
learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None,
|
||||
grpc_channel: "grpc.Channel | None" = None,
|
||||
) -> None:
|
||||
"""Send interactions to the learner.
|
||||
|
||||
@@ -646,7 +658,7 @@ def transitions_stream(
|
||||
shutdown_event: Any, # Event
|
||||
transitions_queue: Queue,
|
||||
timeout: float,
|
||||
) -> Generator[Any, None, services_pb2.Empty]:
|
||||
) -> "Generator[Any, None, services_pb2.Empty]":
|
||||
while not shutdown_event.is_set():
|
||||
try:
|
||||
message = transitions_queue.get(block=True, timeout=timeout)
|
||||
@@ -665,7 +677,7 @@ def interactions_stream(
|
||||
shutdown_event: Any, # Event
|
||||
interactions_queue: Queue,
|
||||
timeout: float,
|
||||
) -> Generator[Any, None, services_pb2.Empty]:
|
||||
) -> "Generator[Any, None, services_pb2.Empty]":
|
||||
while not shutdown_event.is_set():
|
||||
try:
|
||||
message = interactions_queue.get(block=True, timeout=timeout)
|
||||
|
||||
+41
-36
@@ -51,44 +51,47 @@ import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from lerobot.utils.import_utils import require_package
|
||||
from lerobot.utils.import_utils import _grpc_available, require_package
|
||||
|
||||
# Fail fast with a friendly error if the optional ``hilserl`` extra is missing.
|
||||
require_package("grpcio", extra="hilserl", import_name="grpc")
|
||||
if TYPE_CHECKING or _grpc_available:
|
||||
import grpc
|
||||
|
||||
import grpc # noqa: E402
|
||||
import torch # noqa: E402
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE # noqa: E402
|
||||
from safetensors.torch import load_file as load_safetensors # noqa: E402
|
||||
from termcolor import colored # noqa: E402
|
||||
from torch import nn # noqa: E402
|
||||
from torch.multiprocessing import Queue # noqa: E402
|
||||
from torch.optim.optimizer import Optimizer # noqa: E402
|
||||
from lerobot.transport import services_pb2_grpc
|
||||
else:
|
||||
grpc = None
|
||||
services_pb2_grpc = None
|
||||
|
||||
from lerobot.cameras import opencv # noqa: F401, E402
|
||||
from lerobot.common.train_utils import ( # noqa: E402
|
||||
import torch
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from safetensors.torch import load_file as load_safetensors
|
||||
from termcolor import colored
|
||||
from torch import nn
|
||||
from torch.multiprocessing import Queue
|
||||
from torch.optim.optimizer import Optimizer
|
||||
|
||||
from lerobot.cameras import opencv # noqa: F401
|
||||
from lerobot.common.train_utils import (
|
||||
get_step_checkpoint_dir,
|
||||
load_training_state as utils_load_training_state,
|
||||
save_checkpoint,
|
||||
update_last_checkpoint,
|
||||
)
|
||||
from lerobot.common.wandb_utils import WandBLogger # noqa: E402
|
||||
from lerobot.configs import parser # noqa: E402
|
||||
from lerobot.datasets import LeRobotDataset, make_dataset # noqa: E402
|
||||
from lerobot.policies import make_policy, make_pre_post_processors # noqa: E402
|
||||
from lerobot.robots import so_follower # noqa: F401, E402
|
||||
from lerobot.teleoperators import gamepad, so_leader # noqa: F401, E402
|
||||
from lerobot.teleoperators.utils import TeleopEvents # noqa: E402
|
||||
from lerobot.transport import services_pb2_grpc # noqa: E402
|
||||
from lerobot.transport.utils import ( # noqa: E402
|
||||
from lerobot.common.wandb_utils import WandBLogger
|
||||
from lerobot.configs import parser
|
||||
from lerobot.datasets import LeRobotDataset, make_dataset
|
||||
from lerobot.policies import make_policy, make_pre_post_processors
|
||||
from lerobot.robots import so_follower # noqa: F401
|
||||
from lerobot.teleoperators import gamepad, so_leader # noqa: F401
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
from lerobot.transport.utils import (
|
||||
MAX_MESSAGE_SIZE,
|
||||
bytes_to_python_object,
|
||||
bytes_to_transitions,
|
||||
state_to_bytes,
|
||||
)
|
||||
from lerobot.utils.constants import ( # noqa: E402
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
ALGORITHM_DIR,
|
||||
CHECKPOINTS_DIR,
|
||||
@@ -97,26 +100,28 @@ from lerobot.utils.constants import ( # noqa: E402
|
||||
TRAINING_STATE_DIR,
|
||||
TRAINING_STEP,
|
||||
)
|
||||
from lerobot.utils.device_utils import get_safe_torch_device # noqa: E402
|
||||
from lerobot.utils.io_utils import load_json, write_json # noqa: E402
|
||||
from lerobot.utils.process import ProcessSignalHandler # noqa: E402
|
||||
from lerobot.utils.random_utils import set_seed # noqa: E402
|
||||
from lerobot.utils.utils import ( # noqa: E402
|
||||
from lerobot.utils.device_utils import get_safe_torch_device
|
||||
from lerobot.utils.io_utils import load_json, write_json
|
||||
from lerobot.utils.process import ProcessSignalHandler
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
from lerobot.utils.utils import (
|
||||
format_big_number,
|
||||
init_logging,
|
||||
)
|
||||
|
||||
from .algorithms.base import RLAlgorithm # noqa: E402
|
||||
from .algorithms.factory import make_algorithm # noqa: E402
|
||||
from .buffer import ReplayBuffer # noqa: E402
|
||||
from .data_sources import OnlineOfflineMixer # noqa: E402
|
||||
from .learner_service import MAX_WORKERS, SHUTDOWN_TIMEOUT, LearnerService # noqa: E402
|
||||
from .train_rl import TrainRLServerPipelineConfig # noqa: E402
|
||||
from .trainer import RLTrainer # noqa: E402
|
||||
from .algorithms.base import RLAlgorithm
|
||||
from .algorithms.factory import make_algorithm
|
||||
from .buffer import ReplayBuffer
|
||||
from .data_sources import OnlineOfflineMixer
|
||||
from .learner_service import MAX_WORKERS, SHUTDOWN_TIMEOUT, LearnerService
|
||||
from .train_rl import TrainRLServerPipelineConfig
|
||||
from .trainer import RLTrainer
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def train_cli(cfg: TrainRLServerPipelineConfig):
|
||||
# Fail fast with a friendly error if the optional ``hilserl`` extra is missing.
|
||||
require_package("grpcio", extra="hilserl", import_name="grpc")
|
||||
if not use_threads(cfg):
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
|
||||
@@ -18,22 +18,32 @@
|
||||
import logging
|
||||
import time
|
||||
from multiprocessing import Event, Queue
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from lerobot.utils.import_utils import require_package
|
||||
from lerobot.utils.import_utils import _grpc_available
|
||||
|
||||
# Fail fast with a friendly error if the optional ``hilserl`` extra is missing.
|
||||
require_package("grpcio", extra="hilserl", import_name="grpc")
|
||||
from .queue import get_last_item_from_queue
|
||||
|
||||
from lerobot.transport import services_pb2, services_pb2_grpc # noqa: E402
|
||||
from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks # noqa: E402
|
||||
if TYPE_CHECKING or _grpc_available:
|
||||
import grpc
|
||||
|
||||
from .queue import get_last_item_from_queue # noqa: E402
|
||||
from lerobot.transport import services_pb2, services_pb2_grpc
|
||||
from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks
|
||||
|
||||
_ServicerBase = services_pb2_grpc.LearnerServiceServicer
|
||||
else:
|
||||
grpc = None
|
||||
services_pb2 = None
|
||||
services_pb2_grpc = None
|
||||
receive_bytes_in_chunks = None
|
||||
send_bytes_in_chunks = None
|
||||
_ServicerBase = object
|
||||
|
||||
MAX_WORKERS = 3 # Stream parameters, send transitions and interactions
|
||||
SHUTDOWN_TIMEOUT = 10
|
||||
|
||||
|
||||
class LearnerService(services_pb2_grpc.LearnerServiceServicer):
|
||||
class LearnerService(_ServicerBase):
|
||||
"""
|
||||
Implementation of the LearnerService gRPC service
|
||||
This service is used to send parameters to the Actor and receive transitions and interactions from the Actor
|
||||
@@ -56,7 +66,9 @@ class LearnerService(services_pb2_grpc.LearnerServiceServicer):
|
||||
self.interaction_message_queue = interaction_message_queue
|
||||
self.queue_get_timeout = queue_get_timeout
|
||||
|
||||
def StreamParameters(self, request, context): # noqa: N802
|
||||
def StreamParameters( # noqa: N802
|
||||
self, request: "services_pb2.Empty", context: "grpc.ServicerContext"
|
||||
):
|
||||
# TODO: authorize the request
|
||||
logging.info("[LEARNER] Received request to stream parameters from the Actor")
|
||||
|
||||
@@ -91,7 +103,7 @@ class LearnerService(services_pb2_grpc.LearnerServiceServicer):
|
||||
logging.info("[LEARNER] Stream parameters finished")
|
||||
return services_pb2.Empty()
|
||||
|
||||
def SendTransitions(self, request_iterator, _context): # noqa: N802
|
||||
def SendTransitions(self, request_iterator, _context: "grpc.ServicerContext"): # noqa: N802
|
||||
# TODO: authorize the request
|
||||
logging.info("[LEARNER] Received request to receive transitions from the Actor")
|
||||
|
||||
@@ -105,7 +117,7 @@ class LearnerService(services_pb2_grpc.LearnerServiceServicer):
|
||||
logging.debug("[LEARNER] Finished receiving transitions")
|
||||
return services_pb2.Empty()
|
||||
|
||||
def SendInteractions(self, request_iterator, _context): # noqa: N802
|
||||
def SendInteractions(self, request_iterator, _context: "grpc.ServicerContext"): # noqa: N802
|
||||
# TODO: authorize the request
|
||||
logging.info("[LEARNER] Received request to receive interactions from the Actor")
|
||||
|
||||
@@ -119,5 +131,5 @@ class LearnerService(services_pb2_grpc.LearnerServiceServicer):
|
||||
logging.debug("[LEARNER] Finished receiving interactions")
|
||||
return services_pb2.Empty()
|
||||
|
||||
def Ready(self, request, context): # noqa: N802
|
||||
def Ready(self, request: "services_pb2.Empty", context: "grpc.ServicerContext"): # noqa: N802
|
||||
return services_pb2.Empty()
|
||||
|
||||
Reference in New Issue
Block a user