mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 11:09:59 +00:00
refactor(policies): rename policies/sac → policies/gaussian_actor
This commit is contained in:
@@ -17,8 +17,8 @@
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import ClassifierOutput
|
||||
from lerobot.policies.gaussian_actor.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.policies.gaussian_actor.reward_model.modeling_classifier import ClassifierOutput
|
||||
from lerobot.utils.constants import OBS_IMAGE, REWARD
|
||||
from tests.utils import skip_if_package_missing
|
||||
|
||||
@@ -38,7 +38,7 @@ def test_classifier_output():
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_binary_classifier_with_default_params():
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
from lerobot.policies.gaussian_actor.reward_model.modeling_classifier import Classifier
|
||||
|
||||
config = RewardClassifierConfig()
|
||||
config.input_features = {
|
||||
@@ -79,7 +79,7 @@ def test_binary_classifier_with_default_params():
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_multiclass_classifier():
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
from lerobot.policies.gaussian_actor.reward_model.modeling_classifier import Classifier
|
||||
|
||||
num_classes = 5
|
||||
config = RewardClassifierConfig()
|
||||
@@ -118,7 +118,7 @@ def test_multiclass_classifier():
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_default_device():
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
from lerobot.policies.gaussian_actor.reward_model.modeling_classifier import Classifier
|
||||
|
||||
config = RewardClassifierConfig()
|
||||
assert config.device == "cpu"
|
||||
@@ -130,7 +130,7 @@ def test_default_device():
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_explicit_device_setup():
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
from lerobot.policies.gaussian_actor.reward_model.modeling_classifier import Classifier
|
||||
|
||||
config = RewardClassifierConfig(device="cpu")
|
||||
assert config.device == "cpu"
|
||||
|
||||
Reference in New Issue
Block a user