mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 21:19:53 +00:00
refactor(rl): move grpcio guards to runtime entry points
This commit is contained in:
@@ -333,9 +333,6 @@ ignore = [
|
|||||||
"__init__.py" = ["F401", "F403", "E402"]
|
"__init__.py" = ["F401", "F403", "E402"]
|
||||||
# E402: conditional-import guards (TYPE_CHECKING / is_package_available) must precede the imports they protect
|
# E402: conditional-import guards (TYPE_CHECKING / is_package_available) must precede the imports they protect
|
||||||
"src/lerobot/scripts/convert_dataset_v21_to_v30.py" = ["E402"]
|
"src/lerobot/scripts/convert_dataset_v21_to_v30.py" = ["E402"]
|
||||||
"src/lerobot/rl/actor.py" = ["E402"]
|
|
||||||
"src/lerobot/rl/learner.py" = ["E402"]
|
|
||||||
"src/lerobot/rl/learner_service.py" = ["E402"]
|
|
||||||
"src/lerobot/policies/wall_x/**" = ["N801", "N812", "SIM102", "SIM108", "SIM210", "SIM211", "B006", "B007", "SIM118"] # Supprese these as they are coming from original Qwen2_5_vl code TODO(pepijn): refactor original
|
"src/lerobot/policies/wall_x/**" = ["N801", "N812", "SIM102", "SIM108", "SIM210", "SIM211", "B006", "B007", "SIM118"] # Supprese these as they are coming from original Qwen2_5_vl code TODO(pepijn): refactor original
|
||||||
|
|
||||||
[tool.ruff.lint.isort]
|
[tool.ruff.lint.isort]
|
||||||
|
|||||||
+28
-15
@@ -46,18 +46,15 @@ For more details on the complete HILSerl training workflow, see:
|
|||||||
https://github.com/michel-aractingi/lerobot-hilserl-guide
|
https://github.com/michel-aractingi/lerobot-hilserl-guide
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from queue import Empty
|
from queue import Empty
|
||||||
from typing import Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from lerobot.utils.import_utils import require_package
|
|
||||||
|
|
||||||
require_package("grpcio", extra="hilserl", import_name="grpc")
|
|
||||||
|
|
||||||
import grpc
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.multiprocessing import Queue
|
from torch.multiprocessing import Queue
|
||||||
@@ -69,16 +66,8 @@ from lerobot.processor import TransitionKey
|
|||||||
from lerobot.robots import so_follower # noqa: F401
|
from lerobot.robots import so_follower # noqa: F401
|
||||||
from lerobot.teleoperators import gamepad, so_leader # noqa: F401
|
from lerobot.teleoperators import gamepad, so_leader # noqa: F401
|
||||||
from lerobot.teleoperators.utils import TeleopEvents
|
from lerobot.teleoperators.utils import TeleopEvents
|
||||||
from lerobot.transport import services_pb2, services_pb2_grpc
|
|
||||||
from lerobot.transport.utils import (
|
|
||||||
bytes_to_state_dict,
|
|
||||||
grpc_channel_options,
|
|
||||||
python_object_to_bytes,
|
|
||||||
receive_bytes_in_chunks,
|
|
||||||
send_bytes_in_chunks,
|
|
||||||
transitions_to_bytes,
|
|
||||||
)
|
|
||||||
from lerobot.utils.device_utils import get_safe_torch_device
|
from lerobot.utils.device_utils import get_safe_torch_device
|
||||||
|
from lerobot.utils.import_utils import _grpc_available, require_package
|
||||||
from lerobot.utils.process import ProcessSignalHandler
|
from lerobot.utils.process import ProcessSignalHandler
|
||||||
from lerobot.utils.random_utils import set_seed
|
from lerobot.utils.random_utils import set_seed
|
||||||
from lerobot.utils.robot_utils import precise_sleep
|
from lerobot.utils.robot_utils import precise_sleep
|
||||||
@@ -91,6 +80,29 @@ from lerobot.utils.utils import (
|
|||||||
init_logging,
|
init_logging,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING or _grpc_available:
|
||||||
|
import grpc
|
||||||
|
|
||||||
|
from lerobot.transport import services_pb2, services_pb2_grpc
|
||||||
|
from lerobot.transport.utils import (
|
||||||
|
bytes_to_state_dict,
|
||||||
|
grpc_channel_options,
|
||||||
|
python_object_to_bytes,
|
||||||
|
receive_bytes_in_chunks,
|
||||||
|
send_bytes_in_chunks,
|
||||||
|
transitions_to_bytes,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
grpc = None
|
||||||
|
services_pb2 = None
|
||||||
|
services_pb2_grpc = None
|
||||||
|
bytes_to_state_dict = None
|
||||||
|
grpc_channel_options = None
|
||||||
|
python_object_to_bytes = None
|
||||||
|
receive_bytes_in_chunks = None
|
||||||
|
send_bytes_in_chunks = None
|
||||||
|
transitions_to_bytes = None
|
||||||
|
|
||||||
from .gym_manipulator import (
|
from .gym_manipulator import (
|
||||||
make_processors,
|
make_processors,
|
||||||
make_robot_env,
|
make_robot_env,
|
||||||
@@ -105,6 +117,7 @@ from .train_rl import TrainRLServerPipelineConfig
|
|||||||
|
|
||||||
@parser.wrap()
|
@parser.wrap()
|
||||||
def actor_cli(cfg: TrainRLServerPipelineConfig):
|
def actor_cli(cfg: TrainRLServerPipelineConfig):
|
||||||
|
require_package("grpcio", extra="hilserl", import_name="grpc")
|
||||||
cfg.validate()
|
cfg.validate()
|
||||||
display_pid = False
|
display_pid = False
|
||||||
if not use_threads(cfg):
|
if not use_threads(cfg):
|
||||||
|
|||||||
@@ -74,6 +74,7 @@ from lerobot.teleoperators import (
|
|||||||
from lerobot.teleoperators.teleoperator import Teleoperator
|
from lerobot.teleoperators.teleoperator import Teleoperator
|
||||||
from lerobot.teleoperators.utils import TeleopEvents
|
from lerobot.teleoperators.utils import TeleopEvents
|
||||||
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD
|
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD
|
||||||
|
from lerobot.utils.import_utils import require_package
|
||||||
from lerobot.utils.robot_utils import precise_sleep
|
from lerobot.utils.robot_utils import precise_sleep
|
||||||
from lerobot.utils.utils import log_say
|
from lerobot.utils.utils import log_say
|
||||||
|
|
||||||
@@ -312,8 +313,6 @@ def make_robot_env(cfg: HILSerlRobotEnvConfig) -> tuple[gym.Env, Any]:
|
|||||||
# Check if this is a GymHIL simulation environment
|
# Check if this is a GymHIL simulation environment
|
||||||
if cfg.name == "gym_hil":
|
if cfg.name == "gym_hil":
|
||||||
assert cfg.robot is None and cfg.teleop is None, "GymHIL environment does not support robot or teleop"
|
assert cfg.robot is None and cfg.teleop is None, "GymHIL environment does not support robot or teleop"
|
||||||
from lerobot.utils.import_utils import require_package
|
|
||||||
|
|
||||||
require_package("gym-hil", extra="hilserl", import_name="gym_hil")
|
require_package("gym-hil", extra="hilserl", import_name="gym_hil")
|
||||||
import gym_hil # noqa: F401
|
import gym_hil # noqa: F401
|
||||||
|
|
||||||
|
|||||||
+23
-13
@@ -44,6 +44,8 @@ For more details on the complete HILSerl training workflow, see:
|
|||||||
https://github.com/michel-aractingi/lerobot-hilserl-guide
|
https://github.com/michel-aractingi/lerobot-hilserl-guide
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
@@ -51,13 +53,8 @@ import time
|
|||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
from typing import Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from lerobot.utils.import_utils import require_package
|
|
||||||
|
|
||||||
require_package("grpcio", extra="hilserl", import_name="grpc")
|
|
||||||
|
|
||||||
import grpc
|
|
||||||
import torch
|
import torch
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@@ -78,13 +75,6 @@ from lerobot.policies import make_policy, make_pre_post_processors
|
|||||||
from lerobot.robots import so_follower # noqa: F401
|
from lerobot.robots import so_follower # noqa: F401
|
||||||
from lerobot.teleoperators import gamepad, so_leader # noqa: F401
|
from lerobot.teleoperators import gamepad, so_leader # noqa: F401
|
||||||
from lerobot.teleoperators.utils import TeleopEvents
|
from lerobot.teleoperators.utils import TeleopEvents
|
||||||
from lerobot.transport import services_pb2_grpc
|
|
||||||
from lerobot.transport.utils import (
|
|
||||||
MAX_MESSAGE_SIZE,
|
|
||||||
bytes_to_python_object,
|
|
||||||
bytes_to_transitions,
|
|
||||||
state_to_bytes,
|
|
||||||
)
|
|
||||||
from lerobot.utils.constants import (
|
from lerobot.utils.constants import (
|
||||||
ACTION,
|
ACTION,
|
||||||
CHECKPOINTS_DIR,
|
CHECKPOINTS_DIR,
|
||||||
@@ -93,6 +83,7 @@ from lerobot.utils.constants import (
|
|||||||
TRAINING_STATE_DIR,
|
TRAINING_STATE_DIR,
|
||||||
)
|
)
|
||||||
from lerobot.utils.device_utils import get_safe_torch_device
|
from lerobot.utils.device_utils import get_safe_torch_device
|
||||||
|
from lerobot.utils.import_utils import _grpc_available, require_package
|
||||||
from lerobot.utils.process import ProcessSignalHandler
|
from lerobot.utils.process import ProcessSignalHandler
|
||||||
from lerobot.utils.random_utils import set_seed
|
from lerobot.utils.random_utils import set_seed
|
||||||
from lerobot.utils.utils import (
|
from lerobot.utils.utils import (
|
||||||
@@ -100,6 +91,24 @@ from lerobot.utils.utils import (
|
|||||||
init_logging,
|
init_logging,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING or _grpc_available:
|
||||||
|
import grpc
|
||||||
|
|
||||||
|
from lerobot.transport import services_pb2_grpc
|
||||||
|
from lerobot.transport.utils import (
|
||||||
|
MAX_MESSAGE_SIZE,
|
||||||
|
bytes_to_python_object,
|
||||||
|
bytes_to_transitions,
|
||||||
|
state_to_bytes,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
grpc = None
|
||||||
|
services_pb2_grpc = None
|
||||||
|
MAX_MESSAGE_SIZE = None
|
||||||
|
bytes_to_python_object = None
|
||||||
|
bytes_to_transitions = None
|
||||||
|
state_to_bytes = None
|
||||||
|
|
||||||
from .algorithms.base import RLAlgorithm
|
from .algorithms.base import RLAlgorithm
|
||||||
from .algorithms.factory import make_algorithm
|
from .algorithms.factory import make_algorithm
|
||||||
from .buffer import ReplayBuffer
|
from .buffer import ReplayBuffer
|
||||||
@@ -111,6 +120,7 @@ from .trainer import RLTrainer
|
|||||||
|
|
||||||
@parser.wrap()
|
@parser.wrap()
|
||||||
def train_cli(cfg: TrainRLServerPipelineConfig):
|
def train_cli(cfg: TrainRLServerPipelineConfig):
|
||||||
|
require_package("grpcio", extra="hilserl", import_name="grpc")
|
||||||
if not use_threads(cfg):
|
if not use_threads(cfg):
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
|
|
||||||
|
|||||||
@@ -18,13 +18,21 @@
|
|||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from multiprocessing import Event, Queue
|
from multiprocessing import Event, Queue
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from lerobot.utils.import_utils import require_package
|
from lerobot.utils.import_utils import _grpc_available, require_package
|
||||||
|
|
||||||
require_package("grpcio", extra="hilserl", import_name="grpc")
|
if TYPE_CHECKING or _grpc_available:
|
||||||
|
from lerobot.transport import services_pb2, services_pb2_grpc
|
||||||
|
from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks
|
||||||
|
|
||||||
from lerobot.transport import services_pb2, services_pb2_grpc
|
_ServicerBase = services_pb2_grpc.LearnerServiceServicer
|
||||||
from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks
|
else:
|
||||||
|
services_pb2 = None
|
||||||
|
services_pb2_grpc = None
|
||||||
|
receive_bytes_in_chunks = None
|
||||||
|
send_bytes_in_chunks = None
|
||||||
|
_ServicerBase = object
|
||||||
|
|
||||||
from .queue import get_last_item_from_queue
|
from .queue import get_last_item_from_queue
|
||||||
|
|
||||||
@@ -32,7 +40,7 @@ MAX_WORKERS = 3 # Stream parameters, send transitions and interactions
|
|||||||
SHUTDOWN_TIMEOUT = 10
|
SHUTDOWN_TIMEOUT = 10
|
||||||
|
|
||||||
|
|
||||||
class LearnerService(services_pb2_grpc.LearnerServiceServicer):
|
class LearnerService(_ServicerBase):
|
||||||
"""
|
"""
|
||||||
Implementation of the LearnerService gRPC service
|
Implementation of the LearnerService gRPC service
|
||||||
This service is used to send parameters to the Actor and receive transitions and interactions from the Actor
|
This service is used to send parameters to the Actor and receive transitions and interactions from the Actor
|
||||||
@@ -48,6 +56,7 @@ class LearnerService(services_pb2_grpc.LearnerServiceServicer):
|
|||||||
interaction_message_queue: Queue,
|
interaction_message_queue: Queue,
|
||||||
queue_get_timeout: float = 0.001,
|
queue_get_timeout: float = 0.001,
|
||||||
):
|
):
|
||||||
|
require_package("grpcio", extra="hilserl", import_name="grpc")
|
||||||
self.shutdown_event = shutdown_event
|
self.shutdown_event = shutdown_event
|
||||||
self.parameters_queue = parameters_queue
|
self.parameters_queue = parameters_queue
|
||||||
self.seconds_between_pushes = seconds_between_pushes
|
self.seconds_between_pushes = seconds_between_pushes
|
||||||
|
|||||||
@@ -132,6 +132,7 @@ _faker_available = is_package_available("faker")
|
|||||||
_pynput_available = is_package_available("pynput")
|
_pynput_available = is_package_available("pynput")
|
||||||
_pygame_available = is_package_available("pygame")
|
_pygame_available = is_package_available("pygame")
|
||||||
_qwen_vl_utils_available = is_package_available("qwen-vl-utils", import_name="qwen_vl_utils")
|
_qwen_vl_utils_available = is_package_available("qwen-vl-utils", import_name="qwen_vl_utils")
|
||||||
|
_grpc_available = is_package_available("grpcio", import_name="grpc")
|
||||||
_wallx_deps_available = (
|
_wallx_deps_available = (
|
||||||
_transformers_available and _peft_available and _torchdiffeq_available and _qwen_vl_utils_available
|
_transformers_available and _peft_available and _torchdiffeq_available and _qwen_vl_utils_available
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user