From 8065bf15c74df99b75e4d6260984b76c1d163ffd Mon Sep 17 00:00:00 2001 From: Khalil Meftah Date: Tue, 21 Apr 2026 12:06:25 +0200 Subject: [PATCH] fix test for flat dict structure --- tests/rl/test_sac_algorithm.py | 42 +++++++++++++++++++--------------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/tests/rl/test_sac_algorithm.py b/tests/rl/test_sac_algorithm.py index df69b7312..b25472108 100644 --- a/tests/rl/test_sac_algorithm.py +++ b/tests/rl/test_sac_algorithm.py @@ -165,30 +165,32 @@ def test_training_stats_defaults(): def test_get_weights_returns_policy_state_dict(): algorithm, policy = _make_algorithm() weights = algorithm.get_weights() - for key in policy.state_dict(): - assert key in weights - assert torch.equal(weights[key].cpu(), policy.state_dict()[key].cpu()) + assert "policy" in weights + actor_state_dict = policy.actor.state_dict() + for key in actor_state_dict: + assert key in weights["policy"] + assert torch.equal(weights["policy"][key].cpu(), actor_state_dict[key].cpu()) def test_get_weights_includes_discrete_critic_when_present(): - algorithm, policy = _make_algorithm(num_discrete_actions=3, action_dim=6) + algorithm, _ = _make_algorithm(num_discrete_actions=3, action_dim=6) weights = algorithm.get_weights() - dc_keys = [k for k in weights if k.startswith("discrete_critic.")] - assert len(dc_keys) > 0 + assert "discrete_critic" in weights + assert len(weights["discrete_critic"]) > 0 def test_get_weights_excludes_discrete_critic_when_absent(): algorithm, _ = _make_algorithm() weights = algorithm.get_weights() - dc_keys = [k for k in weights if k.startswith("discrete_critic.")] - assert len(dc_keys) == 0 + assert "discrete_critic" not in weights def test_get_weights_are_on_cpu(): - algorithm, _ = _make_algorithm() + algorithm, _ = _make_algorithm(num_discrete_actions=3, action_dim=6) weights = algorithm.get_weights() - for key, tensor in weights.items(): - assert tensor.device == torch.device("cpu"), f"{key} is not on CPU" + for group_name, state_dict in weights.items(): + for key, tensor in state_dict.items(): + assert tensor.device == torch.device("cpu"), f"{group_name}/{key} is not on CPU" # =========================================================================== @@ -408,10 +410,11 @@ def test_load_weights_round_trip(): weights = algo_src.get_weights() algo_dst.load_weights(weights, device="cpu") - for key in weights: + dst_actor_state_dict = algo_dst.policy.actor.state_dict() + for key, tensor in weights["policy"].items(): assert torch.equal( - algo_dst.policy.state_dict()[key].cpu(), - weights[key].cpu(), + dst_actor_state_dict[key].cpu(), + tensor.cpu(), ), f"Policy param '{key}' mismatch after load_weights" @@ -426,12 +429,13 @@ def test_load_weights_round_trip_with_discrete_critic(): weights = algo_src.get_weights() algo_dst.load_weights(weights, device="cpu") - dc_keys = [k for k in weights if k.startswith("discrete_critic.")] - assert len(dc_keys) > 0 - for key in dc_keys: + assert "discrete_critic" in weights + assert len(weights["discrete_critic"]) > 0 + dst_dc_state_dict = algo_dst.discrete_critic.state_dict() + for key, tensor in weights["discrete_critic"].items(): assert torch.equal( - algo_dst.policy.state_dict()[key].cpu(), - weights[key].cpu(), + dst_dc_state_dict[key].cpu(), + tensor.cpu(), ), f"Discrete critic param '{key}' mismatch after load_weights"