refactor(tests): remove grpc import checks from test files for cleaner code

This commit is contained in:
Khalil Meftah
2026-04-27 16:20:13 +02:00
parent 47be90f040
commit 577f14337a
8 changed files with 13 additions and 38 deletions
+6 -9
View File
@@ -15,9 +15,6 @@
# limitations under the License.
import pytest
pytest.importorskip("grpc")
import torch
from torch import Tensor, nn
@@ -240,7 +237,7 @@ def test_gaussian_actor_policy_forward(batch_size: int, state_dim: int, action_d
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
def test_sac_training_through_algorithm(batch_size: int, state_dim: int, action_dim: int):
def test_gaussian_actor_training_through_sac(batch_size: int, state_dim: int, action_dim: int):
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
algorithm, policy = _make_algorithm(config)
@@ -270,7 +267,7 @@ def test_sac_training_through_algorithm(batch_size: int, state_dim: int, action_
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
def test_sac_training_with_visual_input(batch_size: int, state_dim: int, action_dim: int):
def test_gaussian_actor_training_with_visual_input(batch_size: int, state_dim: int, action_dim: int):
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
algorithm, policy = _make_algorithm(config)
@@ -331,7 +328,7 @@ def test_gaussian_actor_policy_with_pretrained_encoder(
assert actor_loss.shape == ()
def test_sac_training_with_shared_encoder():
def test_gaussian_actor_training_with_shared_encoder():
batch_size = 2
action_dim = 10
state_dim = 10
@@ -358,7 +355,7 @@ def test_sac_training_with_shared_encoder():
algorithm.optimizers["actor"].step()
def test_sac_training_with_discrete_critic():
def test_gaussian_actor_training_with_discrete_critic():
batch_size = 2
continuous_action_dim = 9
full_action_dim = continuous_action_dim + 1
@@ -475,7 +472,7 @@ def test_sac_algorithm_with_critics_number_of_heads(num_critics: int):
def test_gaussian_actor_policy_save_and_load(tmp_path):
"""Test that the policy can be saved and loaded from pretrained."""
root = tmp_path / "test_sac_save_and_load"
root = tmp_path / "test_gaussian_actor_save_and_load"
state_dim = 10
action_dim = 10
@@ -506,7 +503,7 @@ def test_gaussian_actor_policy_save_and_load(tmp_path):
def test_gaussian_actor_policy_save_and_load_with_discrete_critic(tmp_path):
"""Discrete critic should be saved/loaded as part of the policy."""
root = tmp_path / "test_sac_save_and_load_discrete"
root = tmp_path / "test_gaussian_actor_save_and_load_discrete"
state_dim = 10
action_dim = 6
-4
View File
@@ -13,10 +13,6 @@
# limitations under the License.
"""Tests for RL data mixing (DataMixer, OnlineOfflineMixer)."""
import pytest
pytest.importorskip("grpc")
import torch
from lerobot.rl.buffer import ReplayBuffer
-4
View File
@@ -18,10 +18,6 @@ import threading
import time
from queue import Queue
import pytest
pytest.importorskip("grpc")
from torch.multiprocessing import Queue as TorchMPQueue
from lerobot.rl.queue import get_last_item_from_queue
+1 -4
View File
@@ -16,9 +16,6 @@
"""Tests for the RL algorithm abstraction and SACAlgorithm implementation."""
import pytest
pytest.importorskip("grpc")
import torch
from lerobot.configs.types import FeatureType, PolicyFeature
@@ -31,7 +28,7 @@ from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
from lerobot.utils.random_utils import set_seed
# ---------------------------------------------------------------------------
# Helpers (reuse patterns from tests/policies/test_sac_policy.py)
# Helpers (reuse patterns from tests/policies/test_gaussian_actor_policy.py)
# ---------------------------------------------------------------------------
-4
View File
@@ -14,10 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
pytest.importorskip("grpc")
import torch
from torch import Tensor
-2
View File
@@ -22,8 +22,6 @@ from unittest.mock import patch
import pytest
pytest.importorskip("grpc")
from lerobot.rl.process import ProcessSignalHandler
-1
View File
@@ -19,7 +19,6 @@ from collections.abc import Callable
import pytest
pytest.importorskip("grpc")
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
import torch # noqa: E402