fix tests that assumed a flat weights dict structure (get_weights now returns grouped state dicts)

This commit is contained in:
Khalil Meftah
2026-04-21 12:06:25 +02:00
parent 8191d2d87f
commit 8065bf15c7
+23 -19
View File
@@ -165,30 +165,32 @@ def test_training_stats_defaults():
def test_get_weights_returns_policy_state_dict():
algorithm, policy = _make_algorithm()
weights = algorithm.get_weights()
for key in policy.state_dict():
assert key in weights
assert torch.equal(weights[key].cpu(), policy.state_dict()[key].cpu())
assert "policy" in weights
actor_state_dict = policy.actor.state_dict()
for key in actor_state_dict:
assert key in weights["policy"]
assert torch.equal(weights["policy"][key].cpu(), actor_state_dict[key].cpu())
def test_get_weights_includes_discrete_critic_when_present():
algorithm, policy = _make_algorithm(num_discrete_actions=3, action_dim=6)
algorithm, _ = _make_algorithm(num_discrete_actions=3, action_dim=6)
weights = algorithm.get_weights()
dc_keys = [k for k in weights if k.startswith("discrete_critic.")]
assert len(dc_keys) > 0
assert "discrete_critic" in weights
assert len(weights["discrete_critic"]) > 0
def test_get_weights_excludes_discrete_critic_when_absent():
algorithm, _ = _make_algorithm()
weights = algorithm.get_weights()
dc_keys = [k for k in weights if k.startswith("discrete_critic.")]
assert len(dc_keys) == 0
assert "discrete_critic" not in weights
def test_get_weights_are_on_cpu():
algorithm, _ = _make_algorithm()
algorithm, _ = _make_algorithm(num_discrete_actions=3, action_dim=6)
weights = algorithm.get_weights()
for key, tensor in weights.items():
assert tensor.device == torch.device("cpu"), f"{key} is not on CPU"
for group_name, state_dict in weights.items():
for key, tensor in state_dict.items():
assert tensor.device == torch.device("cpu"), f"{group_name}/{key} is not on CPU"
# ===========================================================================
@@ -408,10 +410,11 @@ def test_load_weights_round_trip():
weights = algo_src.get_weights()
algo_dst.load_weights(weights, device="cpu")
for key in weights:
dst_actor_state_dict = algo_dst.policy.actor.state_dict()
for key, tensor in weights["policy"].items():
assert torch.equal(
algo_dst.policy.state_dict()[key].cpu(),
weights[key].cpu(),
dst_actor_state_dict[key].cpu(),
tensor.cpu(),
), f"Policy param '{key}' mismatch after load_weights"
@@ -426,12 +429,13 @@ def test_load_weights_round_trip_with_discrete_critic():
weights = algo_src.get_weights()
algo_dst.load_weights(weights, device="cpu")
dc_keys = [k for k in weights if k.startswith("discrete_critic.")]
assert len(dc_keys) > 0
for key in dc_keys:
assert "discrete_critic" in weights
assert len(weights["discrete_critic"]) > 0
dst_dc_state_dict = algo_dst.discrete_critic.state_dict()
for key, tensor in weights["discrete_critic"].items():
assert torch.equal(
algo_dst.policy.state_dict()[key].cpu(),
weights[key].cpu(),
dst_dc_state_dict[key].cpu(),
tensor.cpu(),
), f"Discrete critic param '{key}' mismatch after load_weights"