refactor(rl/sac): consolidate hyperparameter ownership and clean up discrete critic

This commit is contained in:
Khalil Meftah
2026-04-24 13:18:33 +02:00
parent 06255996ea
commit 1ed32210c7
9 changed files with 162 additions and 190 deletions
+10 -9
View File
@@ -404,19 +404,16 @@ def test_sac_training_with_discrete_critic():
def test_sac_algorithm_target_entropy():
"""Target entropy is an SAC hyperparameter and lives on the algorithm."""
# Purpose: build an algorithm from a default config with
# continuous_action_dim=10 / state_dim=10 and check its `target_entropy`.
config = create_default_config(continuous_action_dim=10, state_dim=10)
# NOTE(review): this text is a rendered diff whose +/- markers and indentation
# were lost. The hunk header above reports 19 old lines vs 16 new lines, so the
# next three lines (explicit config + SACAlgorithm construction) are presumably
# the REMOVED side of the change -- confirm against the raw commit.
_, policy = _make_algorithm(config)
algo_config = SACAlgorithmConfig.from_policy_config(config)
algorithm = SACAlgorithm(policy=policy, config=algo_config)
# Rebinds `algorithm`, so only this instance reaches the assert below; the
# second element returned by `_make_algorithm` is discarded here.
algorithm, _ = _make_algorithm(config)
# Expected value is presumably -continuous_action_dim / 2 (10 -> -5.0);
# TODO confirm the formula against SACAlgorithm.
assert algorithm.target_entropy == -5.0
def test_sac_algorithm_target_entropy_with_discrete_action():
# Purpose: same `target_entropy` check as the continuous-only test, but for a
# visual-input config with 6 continuous action dims and a discrete action
# enabled (`has_discrete_action=True`).
config = create_config_with_visual_input(state_dim=10, continuous_action_dim=6, has_discrete_action=True)
# Set on the config before any policy/algorithm is built from it.
config.num_discrete_actions = 5
# NOTE(review): this text is a rendered diff whose +/- markers and indentation
# were lost. The next three lines (manual algo config, policy and SACAlgorithm
# construction) are presumably the REMOVED side of the change -- confirm
# against the raw commit.
algo_config = SACAlgorithmConfig.from_policy_config(config)
policy = GaussianActorPolicy(config=config)
algorithm = SACAlgorithm(policy=policy, config=algo_config)
# Rebinds `algorithm`, so only this instance reaches the assert below.
algorithm, _ = _make_algorithm(config)
# Expected value is presumably -(6 continuous dims + 1 discrete dim) / 2 = -3.5;
# TODO confirm how the discrete action is counted in SACAlgorithm.
assert algorithm.target_entropy == -3.5
@@ -435,8 +432,8 @@ def test_sac_algorithm_temperature():
def test_sac_algorithm_update_target_network():
config = create_default_config(state_dim=10, continuous_action_dim=6)
config.critic_target_update_weight = 1.0
algo_config = SACAlgorithmConfig.from_policy_config(config)
algo_config.critic_target_update_weight = 1.0
policy = GaussianActorPolicy(config=config)
algorithm = SACAlgorithm(policy=policy, config=algo_config)
@@ -454,9 +451,13 @@ def test_sac_algorithm_with_critics_number_of_heads(num_critics: int):
action_dim = 10
state_dim = 10
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
config.num_critics = num_critics
algorithm, policy = _make_algorithm(config)
policy = GaussianActorPolicy(config=config)
policy.train()
algo_config = SACAlgorithmConfig.from_policy_config(config)
algo_config.num_critics = num_critics
algorithm = SACAlgorithm(policy=policy, config=algo_config)
algorithm.make_optimizers_and_scheduler()
assert len(algorithm.critic_ensemble.critics) == num_critics