refactor(rl): move actor weight-sync wire format from policy to algorithm

This commit is contained in:
Khalil Meftah
2026-05-08 21:57:21 +02:00
parent 23811b720d
commit ef927ac830
4 changed files with 26 additions and 23 deletions
+10 -6
View File
@@ -445,35 +445,39 @@ def test_load_weights_ignores_missing_discrete_critic():
def test_actor_side_weight_sync_with_discrete_critic():
"""End-to-end: learner ``algorithm.get_weights()`` -> actor ``policy.load_actor_weights()``."""
# Learner side: train the algorithm so its weights diverge from init.
"""End-to-end: learner ``algorithm.get_weights()`` -> actor ``algorithm.load_weights()``."""
# Learner side: train the source algorithm so its weights diverge from init.
algo_src, _ = _make_algorithm(num_discrete_actions=3, action_dim=6)
algo_src.update(_batch_iterator(action_dim=7))
weights = algo_src.get_weights()
# Actor side: fresh policy, no algorithm/optimizer.
# Actor side: fresh policy + fresh algorithm holding it.
sac_cfg = _make_sac_config(num_discrete_actions=3, action_dim=6)
policy_actor = GaussianActorPolicy(config=sac_cfg)
algo_actor = SACAlgorithm(
policy=policy_actor,
config=SACAlgorithmConfig.from_policy_config(sac_cfg),
)
# Snapshot initial actor state for the "did it change?" assertion below.
initial_discrete_critic_state_dict = {
k: v.clone() for k, v in policy_actor.discrete_critic.state_dict().items()
}
policy_actor.load_actor_weights(weights, device="cpu")
algo_actor.load_weights(weights, device="cpu")
# Actor weights match the learner's exported actor state dict.
actor_state_dict = policy_actor.actor.state_dict()
for key, tensor in weights["policy"].items():
assert torch.equal(actor_state_dict[key].cpu(), tensor.cpu()), (
f"Actor param '{key}' not synced by load_actor_weights"
f"Actor param '{key}' not synced by algorithm.load_weights"
)
# Discrete critic weights match the learner's exported discrete critic.
discrete_critic_state_dict = policy_actor.discrete_critic.state_dict()
for key, tensor in weights["discrete_critic"].items():
assert torch.equal(discrete_critic_state_dict[key].cpu(), tensor.cpu()), (
f"Discrete critic param '{key}' not synced by load_actor_weights"
f"Discrete critic param '{key}' not synced by algorithm.load_weights"
)
# Sanity: the discrete critic actually changed (otherwise the sync is trivial).