refactor(rl): make algorithm a nested config so all SAC hyperparameters are JSON-addressable

This commit is contained in:
Khalil Meftah
2026-04-27 13:39:03 +02:00
parent 21c16a27f0
commit 9ce9e01469
9 changed files with 86 additions and 56 deletions
+9 -11
View File
@@ -348,7 +348,7 @@ def test_optimization_step_can_be_set_for_resume():
def test_make_algorithm_returns_sac_for_sac_policy():
sac_cfg = _make_sac_config()
policy = GaussianActorPolicy(config=sac_cfg)
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
assert isinstance(algorithm, SACAlgorithm)
assert algorithm.optimizers == {}
@@ -357,7 +357,7 @@ def test_make_optimizers_creates_expected_keys():
"""make_optimizers_and_scheduler() should populate the algorithm with Adam optimizers."""
sac_cfg = _make_sac_config()
policy = GaussianActorPolicy(config=sac_cfg)
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
optimizers = algorithm.make_optimizers_and_scheduler()
assert "actor" in optimizers
assert "critic" in optimizers
@@ -370,7 +370,7 @@ def test_actor_side_no_optimizers():
"""Actor-side usage: no optimizers needed, make_optimizers_and_scheduler is not called."""
sac_cfg = _make_sac_config()
policy = GaussianActorPolicy(config=sac_cfg)
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
assert isinstance(algorithm, SACAlgorithm)
assert algorithm.optimizers == {}
@@ -379,18 +379,16 @@ def test_make_algorithm_uses_sac_algorithm_defaults():
"""make_algorithm populates SACAlgorithmConfig with its own defaults."""
sac_cfg = _make_sac_config()
policy = GaussianActorPolicy(config=sac_cfg)
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
assert algorithm.config.utd_ratio == 1
assert algorithm.config.policy_update_freq == 1
assert algorithm.config.grad_clip_norm == 40.0
def test_make_algorithm_raises_for_unknown_type():
class FakeConfig:
type = "unknown_algo"
with pytest.raises(ValueError, match="No RLAlgorithmConfig"):
make_algorithm(policy=None, policy_cfg=FakeConfig(), algorithm_name="unknown_algo")
def test_unknown_algorithm_name_raises_in_registry():
"""The ChoiceRegistry is the source of truth for unknown algorithm names."""
with pytest.raises(KeyError):
RLAlgorithmConfig.get_choice_class("unknown_algo")
# ===========================================================================
@@ -523,5 +521,5 @@ def test_make_algorithm_uses_build_algorithm():
"""make_algorithm should delegate to config.build_algorithm (no hardcoded if/else)."""
sac_cfg = _make_sac_config()
policy = GaussianActorPolicy(config=sac_cfg)
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
assert isinstance(algorithm, SACAlgorithm)