refactor(tests): remove grpc import checks from test files for cleaner code

This commit is contained in:
Khalil Meftah
2026-04-27 16:20:13 +02:00
parent 47be90f040
commit 577f14337a
8 changed files with 13 additions and 38 deletions
+6 -10
View File
@@ -12,15 +12,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Reinforcement learning modules.
Distributed actor / learner entry points (``actor``, ``learner``,
``learner_service``) require ``pip install 'lerobot[hilserl]'``. Algorithms,
buffer, data sources and trainer are gRPC-free and usable standalone.
""" """
Reinforcement learning modules.
Requires: ``pip install 'lerobot[hilserl]'``
"""
from lerobot.utils.import_utils import require_package
require_package("grpcio", extra="hilserl", import_name="grpc")
from .algorithms.base import RLAlgorithm as RLAlgorithm from .algorithms.base import RLAlgorithm as RLAlgorithm
from .algorithms.configs import RLAlgorithmConfig as RLAlgorithmConfig, TrainingStats as TrainingStats from .algorithms.configs import RLAlgorithmConfig as RLAlgorithmConfig, TrainingStats as TrainingStats
@@ -28,7 +25,7 @@ from .algorithms.factory import (
make_algorithm as make_algorithm, make_algorithm as make_algorithm,
make_algorithm_config as make_algorithm_config, make_algorithm_config as make_algorithm_config,
) )
from .algorithms.sac import SACAlgorithm as SACAlgorithm, SACAlgorithmConfig as SACAlgorithmConfig from .algorithms.sac.configuration_sac import SACAlgorithmConfig as SACAlgorithmConfig
from .buffer import ReplayBuffer as ReplayBuffer from .buffer import ReplayBuffer as ReplayBuffer
from .data_sources import DataMixer as DataMixer, OnlineOfflineMixer as OnlineOfflineMixer from .data_sources import DataMixer as DataMixer, OnlineOfflineMixer as OnlineOfflineMixer
from .trainer import RLTrainer as RLTrainer from .trainer import RLTrainer as RLTrainer
@@ -39,7 +36,6 @@ __all__ = [
"TrainingStats", "TrainingStats",
"make_algorithm", "make_algorithm",
"make_algorithm_config", "make_algorithm_config",
"SACAlgorithm",
"SACAlgorithmConfig", "SACAlgorithmConfig",
"RLTrainer", "RLTrainer",
"ReplayBuffer", "ReplayBuffer",
+6 -9
View File
@@ -15,9 +15,6 @@
# limitations under the License. # limitations under the License.
import pytest import pytest
pytest.importorskip("grpc")
import torch import torch
from torch import Tensor, nn from torch import Tensor, nn
@@ -240,7 +237,7 @@ def test_gaussian_actor_policy_forward(batch_size: int, state_dim: int, action_d
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)]) @pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
def test_sac_training_through_algorithm(batch_size: int, state_dim: int, action_dim: int): def test_gaussian_actor_training_through_sac(batch_size: int, state_dim: int, action_dim: int):
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim) config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
algorithm, policy = _make_algorithm(config) algorithm, policy = _make_algorithm(config)
@@ -270,7 +267,7 @@ def test_sac_training_through_algorithm(batch_size: int, state_dim: int, action_
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)]) @pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
def test_sac_training_with_visual_input(batch_size: int, state_dim: int, action_dim: int): def test_gaussian_actor_training_with_visual_input(batch_size: int, state_dim: int, action_dim: int):
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim) config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
algorithm, policy = _make_algorithm(config) algorithm, policy = _make_algorithm(config)
@@ -331,7 +328,7 @@ def test_gaussian_actor_policy_with_pretrained_encoder(
assert actor_loss.shape == () assert actor_loss.shape == ()
def test_sac_training_with_shared_encoder(): def test_gaussian_actor_training_with_shared_encoder():
batch_size = 2 batch_size = 2
action_dim = 10 action_dim = 10
state_dim = 10 state_dim = 10
@@ -358,7 +355,7 @@ def test_sac_training_with_shared_encoder():
algorithm.optimizers["actor"].step() algorithm.optimizers["actor"].step()
def test_sac_training_with_discrete_critic(): def test_gaussian_actor_training_with_discrete_critic():
batch_size = 2 batch_size = 2
continuous_action_dim = 9 continuous_action_dim = 9
full_action_dim = continuous_action_dim + 1 full_action_dim = continuous_action_dim + 1
@@ -475,7 +472,7 @@ def test_sac_algorithm_with_critics_number_of_heads(num_critics: int):
def test_gaussian_actor_policy_save_and_load(tmp_path): def test_gaussian_actor_policy_save_and_load(tmp_path):
"""Test that the policy can be saved and loaded from pretrained.""" """Test that the policy can be saved and loaded from pretrained."""
root = tmp_path / "test_sac_save_and_load" root = tmp_path / "test_gaussian_actor_save_and_load"
state_dim = 10 state_dim = 10
action_dim = 10 action_dim = 10
@@ -506,7 +503,7 @@ def test_gaussian_actor_policy_save_and_load(tmp_path):
def test_gaussian_actor_policy_save_and_load_with_discrete_critic(tmp_path): def test_gaussian_actor_policy_save_and_load_with_discrete_critic(tmp_path):
"""Discrete critic should be saved/loaded as part of the policy.""" """Discrete critic should be saved/loaded as part of the policy."""
root = tmp_path / "test_sac_save_and_load_discrete" root = tmp_path / "test_gaussian_actor_save_and_load_discrete"
state_dim = 10 state_dim = 10
action_dim = 6 action_dim = 6
-4
View File
@@ -13,10 +13,6 @@
# limitations under the License. # limitations under the License.
"""Tests for RL data mixing (DataMixer, OnlineOfflineMixer).""" """Tests for RL data mixing (DataMixer, OnlineOfflineMixer)."""
import pytest
pytest.importorskip("grpc")
import torch import torch
from lerobot.rl.buffer import ReplayBuffer from lerobot.rl.buffer import ReplayBuffer
-4
View File
@@ -18,10 +18,6 @@ import threading
import time import time
from queue import Queue from queue import Queue
import pytest
pytest.importorskip("grpc")
from torch.multiprocessing import Queue as TorchMPQueue from torch.multiprocessing import Queue as TorchMPQueue
from lerobot.rl.queue import get_last_item_from_queue from lerobot.rl.queue import get_last_item_from_queue
+1 -4
View File
@@ -16,9 +16,6 @@
"""Tests for the RL algorithm abstraction and SACAlgorithm implementation.""" """Tests for the RL algorithm abstraction and SACAlgorithm implementation."""
import pytest import pytest
pytest.importorskip("grpc")
import torch import torch
from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.configs.types import FeatureType, PolicyFeature
@@ -31,7 +28,7 @@ from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
from lerobot.utils.random_utils import set_seed from lerobot.utils.random_utils import set_seed
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Helpers (reuse patterns from tests/policies/test_sac_policy.py) # Helpers (reuse patterns from tests/policies/test_gaussian_actor_policy.py)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
-4
View File
@@ -14,10 +14,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import pytest
pytest.importorskip("grpc")
import torch import torch
from torch import Tensor from torch import Tensor
-2
View File
@@ -22,8 +22,6 @@ from unittest.mock import patch
import pytest import pytest
pytest.importorskip("grpc")
from lerobot.rl.process import ProcessSignalHandler from lerobot.rl.process import ProcessSignalHandler
-1
View File
@@ -19,7 +19,6 @@ from collections.abc import Callable
import pytest import pytest
pytest.importorskip("grpc")
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
import torch # noqa: E402 import torch # noqa: E402