refactor(policies): rename policies/sac → policies/gaussian_actor

This commit is contained in:
Khalil Meftah
2026-04-23 19:13:18 +02:00
parent 8065bf15c7
commit 06255996ea
24 changed files with 185 additions and 168 deletions
@@ -17,8 +17,8 @@
import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.sac.reward_model.modeling_classifier import ClassifierOutput
from lerobot.policies.gaussian_actor.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.gaussian_actor.reward_model.modeling_classifier import ClassifierOutput
from lerobot.utils.constants import OBS_IMAGE, REWARD
from tests.utils import skip_if_package_missing
@@ -38,7 +38,7 @@ def test_classifier_output():
@skip_if_package_missing("transformers")
def test_binary_classifier_with_default_params():
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
from lerobot.policies.gaussian_actor.reward_model.modeling_classifier import Classifier
config = RewardClassifierConfig()
config.input_features = {
@@ -79,7 +79,7 @@ def test_binary_classifier_with_default_params():
@skip_if_package_missing("transformers")
def test_multiclass_classifier():
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
from lerobot.policies.gaussian_actor.reward_model.modeling_classifier import Classifier
num_classes = 5
config = RewardClassifierConfig()
@@ -118,7 +118,7 @@ def test_multiclass_classifier():
@skip_if_package_missing("transformers")
def test_default_device():
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
from lerobot.policies.gaussian_actor.reward_model.modeling_classifier import Classifier
config = RewardClassifierConfig()
assert config.device == "cpu"
@@ -130,7 +130,7 @@ def test_default_device():
@skip_if_package_missing("transformers")
def test_explicit_device_setup():
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
from lerobot.policies.gaussian_actor.reward_model.modeling_classifier import Classifier
config = RewardClassifierConfig(device="cpu")
assert config.device == "cpu"
@@ -17,19 +17,19 @@
import pytest
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.policies.sac.configuration_sac import (
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import (
ActorLearnerConfig,
ActorNetworkConfig,
ConcurrencyConfig,
CriticNetworkConfig,
GaussianActorConfig,
PolicyConfig,
SACConfig,
)
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
def test_sac_config_default_initialization():
config = SACConfig()
def test_gaussian_actor_config_default_initialization():
config = GaussianActorConfig()
assert config.normalization_mapping == {
"VISUAL": NormalizationMode.MEAN_STD,
@@ -175,8 +175,8 @@ def test_concurrency_config():
assert config.learner == "threads"
def test_sac_config_custom_initialization():
config = SACConfig(
def test_gaussian_actor_config_custom_initialization():
config = GaussianActorConfig(
device="cpu",
discount=0.95,
temperature_init=0.5,
@@ -190,7 +190,7 @@ def test_sac_config_custom_initialization():
def test_validate_features():
config = SACConfig(
config = GaussianActorConfig(
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
)
@@ -198,7 +198,7 @@ def test_validate_features():
def test_validate_features_missing_observation():
config = SACConfig(
config = GaussianActorConfig(
input_features={"wrong_key": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
)
@@ -209,7 +209,7 @@ def test_validate_features_missing_observation():
def test_validate_features_missing_action():
config = SACConfig(
config = GaussianActorConfig(
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))},
output_features={"wrong_key": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
)
@@ -22,8 +22,8 @@ import torch
from torch import Tensor, nn
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.modeling_sac import MLP, SACPolicy
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import MLP, GaussianActorPolicy
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
from lerobot.utils.random_utils import seeded_context, set_seed
@@ -81,9 +81,9 @@ def test_mlp_with_custom_final_activation():
assert (y >= -1).all() and (y <= 1).all()
def test_sac_policy_with_default_args():
def test_gaussian_actor_policy_with_default_args():
with pytest.raises(ValueError, match="should be an instance of class `PreTrainedConfig`"):
SACPolicy()
GaussianActorPolicy()
def create_dummy_state(batch_size: int, state_dim: int = 10) -> Tensor:
@@ -142,12 +142,12 @@ def create_observation_batch_with_visual_input(batch_size: int = 8, state_dim: i
def create_default_config(
state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False
) -> SACConfig:
) -> GaussianActorConfig:
action_dim = continuous_action_dim
if has_discrete_action:
action_dim += 1
config = SACConfig(
config = GaussianActorConfig(
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(continuous_action_dim,))},
dataset_stats={
@@ -167,7 +167,7 @@ def create_default_config(
def create_config_with_visual_input(
state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False
) -> SACConfig:
) -> GaussianActorConfig:
config = create_default_config(
state_dim=state_dim,
continuous_action_dim=continuous_action_dim,
@@ -186,9 +186,9 @@ def create_config_with_visual_input(
return config
def _make_algorithm(config: SACConfig) -> tuple[SACAlgorithm, SACPolicy]:
def _make_algorithm(config: GaussianActorConfig) -> tuple[SACAlgorithm, GaussianActorPolicy]:
"""Helper to create policy + algorithm pair for tests that need critics."""
policy = SACPolicy(config=config)
policy = GaussianActorPolicy(config=config)
policy.train()
algo_config = SACAlgorithmConfig.from_policy_config(config)
algorithm = SACAlgorithm(policy=policy, config=algo_config)
@@ -197,9 +197,9 @@ def _make_algorithm(config: SACConfig) -> tuple[SACAlgorithm, SACPolicy]:
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
def test_sac_policy_select_action(batch_size: int, state_dim: int, action_dim: int):
def test_gaussian_actor_policy_select_action(batch_size: int, state_dim: int, action_dim: int):
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
policy = SACPolicy(config=config)
policy = GaussianActorPolicy(config=config)
policy.eval()
with torch.no_grad():
@@ -209,11 +209,11 @@ def test_sac_policy_select_action(batch_size: int, state_dim: int, action_dim: i
assert selected_action.shape[-1] == action_dim
def test_sac_policy_select_action_with_discrete():
def test_gaussian_actor_policy_select_action_with_discrete():
"""select_action should return continuous + discrete actions."""
config = create_default_config(state_dim=10, continuous_action_dim=6)
config.num_discrete_actions = 3
policy = SACPolicy(config=config)
policy = GaussianActorPolicy(config=config)
policy.eval()
with torch.no_grad():
@@ -225,9 +225,9 @@ def test_sac_policy_select_action_with_discrete():
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
def test_sac_policy_forward(batch_size: int, state_dim: int, action_dim: int):
def test_gaussian_actor_policy_forward(batch_size: int, state_dim: int, action_dim: int):
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
policy = SACPolicy(config=config)
policy = GaussianActorPolicy(config=config)
policy.eval()
batch = create_default_train_batch(batch_size=batch_size, action_dim=action_dim, state_dim=state_dim)
@@ -307,7 +307,7 @@ def test_sac_training_with_visual_input(batch_size: int, state_dim: int, action_
[(1, 6, 6, "lerobot/resnet10"), (1, 6, 6, "facebook/convnext-base-224")],
)
@pytest.mark.skipif(not TRANSFORMERS_AVAILABLE, reason="Transformers are not installed")
def test_sac_policy_with_pretrained_encoder(
def test_gaussian_actor_policy_with_pretrained_encoder(
batch_size: int, state_dim: int, action_dim: int, vision_encoder_name: str
):
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
@@ -415,7 +415,7 @@ def test_sac_algorithm_target_entropy_with_discrete_action():
config = create_config_with_visual_input(state_dim=10, continuous_action_dim=6, has_discrete_action=True)
config.num_discrete_actions = 5
algo_config = SACAlgorithmConfig.from_policy_config(config)
policy = SACPolicy(config=config)
policy = GaussianActorPolicy(config=config)
algorithm = SACAlgorithm(policy=policy, config=algo_config)
assert algorithm.target_entropy == -3.5
@@ -425,7 +425,7 @@ def test_sac_algorithm_temperature():
config = create_default_config(continuous_action_dim=10, state_dim=10)
algo_config = SACAlgorithmConfig.from_policy_config(config)
policy = SACPolicy(config=config)
policy = GaussianActorPolicy(config=config)
algorithm = SACAlgorithm(policy=policy, config=algo_config)
assert algorithm.temperature == pytest.approx(1.0)
@@ -437,7 +437,7 @@ def test_sac_algorithm_update_target_network():
config = create_default_config(state_dim=10, continuous_action_dim=6)
config.critic_target_update_weight = 1.0
algo_config = SACAlgorithmConfig.from_policy_config(config)
policy = SACPolicy(config=config)
policy = GaussianActorPolicy(config=config)
algorithm = SACAlgorithm(policy=policy, config=algo_config)
for p in algorithm.critic_ensemble.parameters():
@@ -472,7 +472,7 @@ def test_sac_algorithm_with_critics_number_of_heads(num_critics: int):
algorithm.optimizers["critic"].step()
def test_sac_policy_save_and_load(tmp_path):
def test_gaussian_actor_policy_save_and_load(tmp_path):
"""Test that the policy can be saved and loaded from pretrained."""
root = tmp_path / "test_sac_save_and_load"
@@ -481,10 +481,10 @@ def test_sac_policy_save_and_load(tmp_path):
batch_size = 2
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
policy = SACPolicy(config=config)
policy = GaussianActorPolicy(config=config)
policy.eval()
policy.save_pretrained(root)
loaded_policy = SACPolicy.from_pretrained(root, config=config)
loaded_policy = GaussianActorPolicy.from_pretrained(root, config=config)
loaded_policy.eval()
assert policy.state_dict().keys() == loaded_policy.state_dict().keys()
@@ -503,7 +503,7 @@ def test_sac_policy_save_and_load(tmp_path):
assert torch.allclose(actions, loaded_actions)
def test_sac_policy_save_and_load_with_discrete_critic(tmp_path):
def test_gaussian_actor_policy_save_and_load_with_discrete_critic(tmp_path):
"""Discrete critic should be saved/loaded as part of the policy."""
root = tmp_path / "test_sac_save_and_load_discrete"
@@ -512,11 +512,11 @@ def test_sac_policy_save_and_load_with_discrete_critic(tmp_path):
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
config.num_discrete_actions = 3
policy = SACPolicy(config=config)
policy = GaussianActorPolicy(config=config)
policy.eval()
policy.save_pretrained(root)
loaded_policy = SACPolicy.from_pretrained(root, config=config)
loaded_policy = GaussianActorPolicy.from_pretrained(root, config=config)
loaded_policy.eval()
assert loaded_policy.discrete_critic is not None