refactor: decouple policy from algorithm

This commit is contained in:
Khalil Meftah
2026-03-11 16:49:14 +01:00
parent 8d50be9faa
commit 1f5487eea8
12 changed files with 769 additions and 908 deletions
+2 -4
View File
@@ -356,7 +356,7 @@ def test_learner_algorithm_wiring():
# get_weights -> state_to_bytes round-trip
weights = algorithm.get_weights()
assert "policy" in weights
assert len(weights) > 0
serialized = state_to_bytes(weights)
assert isinstance(serialized, bytes)
assert len(serialized) > 0
@@ -430,8 +430,6 @@ def test_initial_and_periodic_weight_push_consistency():
periodic_decoded = bytes_to_state_dict(periodic_bytes)
assert initial_decoded.keys() == periodic_decoded.keys()
for key in initial_decoded:
assert initial_decoded[key].keys() == periodic_decoded[key].keys()
def test_actor_side_algorithm_select_action_and_load_weights():
@@ -462,7 +460,7 @@ def test_actor_side_algorithm_select_action_and_load_weights():
# select_action should work
obs = {OBS_STATE: torch.randn(state_dim)}
action = algorithm.select_action(obs)
action = policy.select_action(obs)
assert action.shape == (action_dim,)
# Simulate receiving weights from learner