mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 03:30:10 +00:00
refactor(tests): remove grpc import checks from test files for cleaner code
This commit is contained in:
@@ -12,15 +12,12 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Reinforcement learning modules.
|
||||||
|
|
||||||
|
Distributed actor / learner entry points (``actor``, ``learner``,
|
||||||
|
``learner_service``) require ``pip install 'lerobot[hilserl]'``. Algorithms,
|
||||||
|
buffer, data sources and trainer are gRPC-free and usable standalone.
|
||||||
"""
|
"""
|
||||||
Reinforcement learning modules.
|
|
||||||
|
|
||||||
Requires: ``pip install 'lerobot[hilserl]'``
|
|
||||||
"""
|
|
||||||
|
|
||||||
from lerobot.utils.import_utils import require_package
|
|
||||||
|
|
||||||
require_package("grpcio", extra="hilserl", import_name="grpc")
|
|
||||||
|
|
||||||
from .algorithms.base import RLAlgorithm as RLAlgorithm
|
from .algorithms.base import RLAlgorithm as RLAlgorithm
|
||||||
from .algorithms.configs import RLAlgorithmConfig as RLAlgorithmConfig, TrainingStats as TrainingStats
|
from .algorithms.configs import RLAlgorithmConfig as RLAlgorithmConfig, TrainingStats as TrainingStats
|
||||||
@@ -28,7 +25,7 @@ from .algorithms.factory import (
|
|||||||
make_algorithm as make_algorithm,
|
make_algorithm as make_algorithm,
|
||||||
make_algorithm_config as make_algorithm_config,
|
make_algorithm_config as make_algorithm_config,
|
||||||
)
|
)
|
||||||
from .algorithms.sac import SACAlgorithm as SACAlgorithm, SACAlgorithmConfig as SACAlgorithmConfig
|
from .algorithms.sac.configuration_sac import SACAlgorithmConfig as SACAlgorithmConfig
|
||||||
from .buffer import ReplayBuffer as ReplayBuffer
|
from .buffer import ReplayBuffer as ReplayBuffer
|
||||||
from .data_sources import DataMixer as DataMixer, OnlineOfflineMixer as OnlineOfflineMixer
|
from .data_sources import DataMixer as DataMixer, OnlineOfflineMixer as OnlineOfflineMixer
|
||||||
from .trainer import RLTrainer as RLTrainer
|
from .trainer import RLTrainer as RLTrainer
|
||||||
@@ -39,7 +36,6 @@ __all__ = [
|
|||||||
"TrainingStats",
|
"TrainingStats",
|
||||||
"make_algorithm",
|
"make_algorithm",
|
||||||
"make_algorithm_config",
|
"make_algorithm_config",
|
||||||
"SACAlgorithm",
|
|
||||||
"SACAlgorithmConfig",
|
"SACAlgorithmConfig",
|
||||||
"RLTrainer",
|
"RLTrainer",
|
||||||
"ReplayBuffer",
|
"ReplayBuffer",
|
||||||
|
|||||||
@@ -15,9 +15,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
pytest.importorskip("grpc")
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
|
||||||
@@ -240,7 +237,7 @@ def test_gaussian_actor_policy_forward(batch_size: int, state_dim: int, action_d
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
|
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
|
||||||
def test_sac_training_through_algorithm(batch_size: int, state_dim: int, action_dim: int):
|
def test_gaussian_actor_training_through_sac(batch_size: int, state_dim: int, action_dim: int):
|
||||||
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
|
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||||
algorithm, policy = _make_algorithm(config)
|
algorithm, policy = _make_algorithm(config)
|
||||||
|
|
||||||
@@ -270,7 +267,7 @@ def test_sac_training_through_algorithm(batch_size: int, state_dim: int, action_
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
|
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
|
||||||
def test_sac_training_with_visual_input(batch_size: int, state_dim: int, action_dim: int):
|
def test_gaussian_actor_training_with_visual_input(batch_size: int, state_dim: int, action_dim: int):
|
||||||
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
|
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||||
algorithm, policy = _make_algorithm(config)
|
algorithm, policy = _make_algorithm(config)
|
||||||
|
|
||||||
@@ -331,7 +328,7 @@ def test_gaussian_actor_policy_with_pretrained_encoder(
|
|||||||
assert actor_loss.shape == ()
|
assert actor_loss.shape == ()
|
||||||
|
|
||||||
|
|
||||||
def test_sac_training_with_shared_encoder():
|
def test_gaussian_actor_training_with_shared_encoder():
|
||||||
batch_size = 2
|
batch_size = 2
|
||||||
action_dim = 10
|
action_dim = 10
|
||||||
state_dim = 10
|
state_dim = 10
|
||||||
@@ -358,7 +355,7 @@ def test_sac_training_with_shared_encoder():
|
|||||||
algorithm.optimizers["actor"].step()
|
algorithm.optimizers["actor"].step()
|
||||||
|
|
||||||
|
|
||||||
def test_sac_training_with_discrete_critic():
|
def test_gaussian_actor_training_with_discrete_critic():
|
||||||
batch_size = 2
|
batch_size = 2
|
||||||
continuous_action_dim = 9
|
continuous_action_dim = 9
|
||||||
full_action_dim = continuous_action_dim + 1
|
full_action_dim = continuous_action_dim + 1
|
||||||
@@ -475,7 +472,7 @@ def test_sac_algorithm_with_critics_number_of_heads(num_critics: int):
|
|||||||
|
|
||||||
def test_gaussian_actor_policy_save_and_load(tmp_path):
|
def test_gaussian_actor_policy_save_and_load(tmp_path):
|
||||||
"""Test that the policy can be saved and loaded from pretrained."""
|
"""Test that the policy can be saved and loaded from pretrained."""
|
||||||
root = tmp_path / "test_sac_save_and_load"
|
root = tmp_path / "test_gaussian_actor_save_and_load"
|
||||||
|
|
||||||
state_dim = 10
|
state_dim = 10
|
||||||
action_dim = 10
|
action_dim = 10
|
||||||
@@ -506,7 +503,7 @@ def test_gaussian_actor_policy_save_and_load(tmp_path):
|
|||||||
|
|
||||||
def test_gaussian_actor_policy_save_and_load_with_discrete_critic(tmp_path):
|
def test_gaussian_actor_policy_save_and_load_with_discrete_critic(tmp_path):
|
||||||
"""Discrete critic should be saved/loaded as part of the policy."""
|
"""Discrete critic should be saved/loaded as part of the policy."""
|
||||||
root = tmp_path / "test_sac_save_and_load_discrete"
|
root = tmp_path / "test_gaussian_actor_save_and_load_discrete"
|
||||||
|
|
||||||
state_dim = 10
|
state_dim = 10
|
||||||
action_dim = 6
|
action_dim = 6
|
||||||
|
|||||||
@@ -13,10 +13,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Tests for RL data mixing (DataMixer, OnlineOfflineMixer)."""
|
"""Tests for RL data mixing (DataMixer, OnlineOfflineMixer)."""
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
pytest.importorskip("grpc")
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.rl.buffer import ReplayBuffer
|
from lerobot.rl.buffer import ReplayBuffer
|
||||||
|
|||||||
@@ -18,10 +18,6 @@ import threading
|
|||||||
import time
|
import time
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
pytest.importorskip("grpc")
|
|
||||||
|
|
||||||
from torch.multiprocessing import Queue as TorchMPQueue
|
from torch.multiprocessing import Queue as TorchMPQueue
|
||||||
|
|
||||||
from lerobot.rl.queue import get_last_item_from_queue
|
from lerobot.rl.queue import get_last_item_from_queue
|
||||||
|
|||||||
@@ -16,9 +16,6 @@
|
|||||||
"""Tests for the RL algorithm abstraction and SACAlgorithm implementation."""
|
"""Tests for the RL algorithm abstraction and SACAlgorithm implementation."""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
pytest.importorskip("grpc")
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||||
@@ -31,7 +28,7 @@ from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
|
|||||||
from lerobot.utils.random_utils import set_seed
|
from lerobot.utils.random_utils import set_seed
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Helpers (reuse patterns from tests/policies/test_sac_policy.py)
|
# Helpers (reuse patterns from tests/policies/test_gaussian_actor_policy.py)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -14,10 +14,6 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
pytest.importorskip("grpc")
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
|
|||||||
@@ -22,8 +22,6 @@ from unittest.mock import patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
pytest.importorskip("grpc")
|
|
||||||
|
|
||||||
from lerobot.rl.process import ProcessSignalHandler
|
from lerobot.rl.process import ProcessSignalHandler
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ from collections.abc import Callable
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
pytest.importorskip("grpc")
|
|
||||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||||
|
|
||||||
import torch # noqa: E402
|
import torch # noqa: E402
|
||||||
|
|||||||
Reference in New Issue
Block a user