mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 19:19:56 +00:00
refactor(tests): remove grpc import checks from test files for cleaner code
This commit is contained in:
@@ -15,9 +15,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("grpc")
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
@@ -240,7 +237,7 @@ def test_gaussian_actor_policy_forward(batch_size: int, state_dim: int, action_d
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
|
||||
def test_sac_training_through_algorithm(batch_size: int, state_dim: int, action_dim: int):
|
||||
def test_gaussian_actor_training_through_sac(batch_size: int, state_dim: int, action_dim: int):
|
||||
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
algorithm, policy = _make_algorithm(config)
|
||||
|
||||
@@ -270,7 +267,7 @@ def test_sac_training_through_algorithm(batch_size: int, state_dim: int, action_
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
|
||||
def test_sac_training_with_visual_input(batch_size: int, state_dim: int, action_dim: int):
|
||||
def test_gaussian_actor_training_with_visual_input(batch_size: int, state_dim: int, action_dim: int):
|
||||
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
algorithm, policy = _make_algorithm(config)
|
||||
|
||||
@@ -331,7 +328,7 @@ def test_gaussian_actor_policy_with_pretrained_encoder(
|
||||
assert actor_loss.shape == ()
|
||||
|
||||
|
||||
def test_sac_training_with_shared_encoder():
|
||||
def test_gaussian_actor_training_with_shared_encoder():
|
||||
batch_size = 2
|
||||
action_dim = 10
|
||||
state_dim = 10
|
||||
@@ -358,7 +355,7 @@ def test_sac_training_with_shared_encoder():
|
||||
algorithm.optimizers["actor"].step()
|
||||
|
||||
|
||||
def test_sac_training_with_discrete_critic():
|
||||
def test_gaussian_actor_training_with_discrete_critic():
|
||||
batch_size = 2
|
||||
continuous_action_dim = 9
|
||||
full_action_dim = continuous_action_dim + 1
|
||||
@@ -475,7 +472,7 @@ def test_sac_algorithm_with_critics_number_of_heads(num_critics: int):
|
||||
|
||||
def test_gaussian_actor_policy_save_and_load(tmp_path):
|
||||
"""Test that the policy can be saved and loaded from pretrained."""
|
||||
root = tmp_path / "test_sac_save_and_load"
|
||||
root = tmp_path / "test_gaussian_actor_save_and_load"
|
||||
|
||||
state_dim = 10
|
||||
action_dim = 10
|
||||
@@ -506,7 +503,7 @@ def test_gaussian_actor_policy_save_and_load(tmp_path):
|
||||
|
||||
def test_gaussian_actor_policy_save_and_load_with_discrete_critic(tmp_path):
|
||||
"""Discrete critic should be saved/loaded as part of the policy."""
|
||||
root = tmp_path / "test_sac_save_and_load_discrete"
|
||||
root = tmp_path / "test_gaussian_actor_save_and_load_discrete"
|
||||
|
||||
state_dim = 10
|
||||
action_dim = 6
|
||||
|
||||
Reference in New Issue
Block a user