refactor(rl/sac): consolidate hyperparameter ownership and clean up discrete critic

This commit is contained in:
Khalil Meftah
2026-04-24 13:18:33 +02:00
parent 06255996ea
commit 1ed32210c7
9 changed files with 162 additions and 190 deletions
+62 -20
View File
@@ -44,8 +44,6 @@ def _make_sac_config(
state_dim: int = 10,
action_dim: int = 6,
num_discrete_actions: int | None = None,
utd_ratio: int = 1,
policy_update_freq: int = 1,
with_images: bool = False,
) -> GaussianActorConfig:
config = GaussianActorConfig(
@@ -55,10 +53,7 @@ def _make_sac_config(
OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
},
utd_ratio=utd_ratio,
policy_update_freq=policy_update_freq,
num_discrete_actions=num_discrete_actions,
use_torch_compile=False,
)
if with_images:
config.input_features[OBS_IMAGE] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84))
@@ -83,14 +78,14 @@ def _make_algorithm(
sac_cfg = _make_sac_config(
state_dim=state_dim,
action_dim=action_dim,
utd_ratio=utd_ratio,
policy_update_freq=policy_update_freq,
num_discrete_actions=num_discrete_actions,
with_images=with_images,
)
policy = GaussianActorPolicy(config=sac_cfg)
policy.train()
algo_config = SACAlgorithmConfig.from_policy_config(sac_cfg)
algo_config.utd_ratio = utd_ratio
algo_config.policy_update_freq = policy_update_freq
algorithm = SACAlgorithm(policy=policy, config=algo_config)
algorithm.make_optimizers_and_scheduler()
return algorithm, policy
@@ -136,13 +131,16 @@ def test_sac_algorithm_config_registered():
def test_sac_algorithm_config_from_policy_config():
"""from_policy_config should copy algorithm hyperparameters from the policy config."""
sac_cfg = _make_sac_config(utd_ratio=4, policy_update_freq=2)
"""from_policy_config embeds the policy config and uses SAC defaults."""
sac_cfg = _make_sac_config()
algo_cfg = SACAlgorithmConfig.from_policy_config(sac_cfg)
assert algo_cfg.sac_config is sac_cfg
assert algo_cfg.utd_ratio == 4
assert algo_cfg.policy_update_freq == 2
assert algo_cfg.grad_clip_norm == sac_cfg.grad_clip_norm
assert algo_cfg.policy_config is sac_cfg
assert algo_cfg.discrete_critic_network_kwargs is sac_cfg.discrete_critic_network_kwargs
# Defaults come from SACAlgorithmConfig, not from the policy config.
assert algo_cfg.utd_ratio == 1
assert algo_cfg.policy_update_freq == 1
assert algo_cfg.grad_clip_norm == 40.0
assert algo_cfg.actor_lr == 3e-4
# ===========================================================================
@@ -377,12 +375,14 @@ def test_actor_side_no_optimizers():
assert algorithm.optimizers == {}
def test_make_algorithm_copies_config_fields():
sac_cfg = _make_sac_config(utd_ratio=5, policy_update_freq=3)
def test_make_algorithm_uses_sac_algorithm_defaults():
"""make_algorithm populates SACAlgorithmConfig with its own defaults."""
sac_cfg = _make_sac_config()
policy = GaussianActorPolicy(config=sac_cfg)
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
assert algorithm.config.utd_ratio == 5
assert algorithm.config.policy_update_freq == 3
assert algorithm.config.utd_ratio == 1
assert algorithm.config.policy_update_freq == 1
assert algorithm.config.grad_clip_norm == 40.0
def test_make_algorithm_raises_for_unknown_type():
@@ -431,10 +431,10 @@ def test_load_weights_round_trip_with_discrete_critic():
assert "discrete_critic" in weights
assert len(weights["discrete_critic"]) > 0
dst_dc_state_dict = algo_dst.discrete_critic.state_dict()
dst_discrete_critic_state_dict = algo_dst.policy.discrete_critic.state_dict()
for key, tensor in weights["discrete_critic"].items():
assert torch.equal(
dst_dc_state_dict[key].cpu(),
dst_discrete_critic_state_dict[key].cpu(),
tensor.cpu(),
), f"Discrete critic param '{key}' mismatch after load_weights"
@@ -446,6 +446,47 @@ def test_load_weights_ignores_missing_discrete_critic():
algorithm.load_weights(weights, device="cpu")
def test_actor_side_weight_sync_with_discrete_critic():
"""End-to-end: learner ``algorithm.get_weights()`` -> actor ``policy.load_actor_weights()``."""
# Learner side: train the algorithm so its weights diverge from init.
algo_src, _ = _make_algorithm(num_discrete_actions=3, action_dim=6)
algo_src.update(_batch_iterator(action_dim=7))
weights = algo_src.get_weights()
# Actor side: fresh policy, no algorithm/optimizer.
sac_cfg = _make_sac_config(num_discrete_actions=3, action_dim=6)
policy_actor = GaussianActorPolicy(config=sac_cfg)
# Snapshot initial actor state for the "did it change?" assertion below.
initial_discrete_critic_state_dict = {
k: v.clone() for k, v in policy_actor.discrete_critic.state_dict().items()
}
policy_actor.load_actor_weights(weights, device="cpu")
# Actor weights match the learner's exported actor state dict.
actor_state_dict = policy_actor.actor.state_dict()
for key, tensor in weights["policy"].items():
assert torch.equal(actor_state_dict[key].cpu(), tensor.cpu()), (
f"Actor param '{key}' not synced by load_actor_weights"
)
# Discrete critic weights match the learner's exported discrete critic.
discrete_critic_state_dict = policy_actor.discrete_critic.state_dict()
for key, tensor in weights["discrete_critic"].items():
assert torch.equal(discrete_critic_state_dict[key].cpu(), tensor.cpu()), (
f"Discrete critic param '{key}' not synced by load_actor_weights"
)
# Sanity: the discrete critic actually changed (otherwise the sync is trivial).
changed = any(
not torch.equal(initial_discrete_critic_state_dict[key], discrete_critic_state_dict[key])
for key in initial_discrete_critic_state_dict
if key in discrete_critic_state_dict
)
assert changed, "Discrete critic weights did not change between init and after sync"
# ===========================================================================
# TrainingStats generic losses dict
# ===========================================================================
@@ -468,8 +509,9 @@ def test_training_stats_generic_losses():
def test_build_algorithm_via_config():
"""SACAlgorithmConfig.build_algorithm should produce a working SACAlgorithm."""
sac_cfg = _make_sac_config(utd_ratio=2)
sac_cfg = _make_sac_config()
algo_config = SACAlgorithmConfig.from_policy_config(sac_cfg)
algo_config.utd_ratio = 2
policy = GaussianActorPolicy(config=sac_cfg)
algorithm = algo_config.build_algorithm(policy)