mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-17 17:50:09 +00:00
refactor: decouple policy from algorithm
This commit is contained in:
@@ -356,7 +356,7 @@ def test_learner_algorithm_wiring():
|
||||
|
||||
# get_weights -> state_to_bytes round-trip
|
||||
weights = algorithm.get_weights()
|
||||
assert "policy" in weights
|
||||
assert len(weights) > 0
|
||||
serialized = state_to_bytes(weights)
|
||||
assert isinstance(serialized, bytes)
|
||||
assert len(serialized) > 0
|
||||
@@ -430,8 +430,6 @@ def test_initial_and_periodic_weight_push_consistency():
|
||||
periodic_decoded = bytes_to_state_dict(periodic_bytes)
|
||||
|
||||
assert initial_decoded.keys() == periodic_decoded.keys()
|
||||
for key in initial_decoded:
|
||||
assert initial_decoded[key].keys() == periodic_decoded[key].keys()
|
||||
|
||||
|
||||
def test_actor_side_algorithm_select_action_and_load_weights():
|
||||
@@ -462,7 +460,7 @@ def test_actor_side_algorithm_select_action_and_load_weights():
|
||||
|
||||
# select_action should work
|
||||
obs = {OBS_STATE: torch.randn(state_dim)}
|
||||
action = algorithm.select_action(obs)
|
||||
action = policy.select_action(obs)
|
||||
assert action.shape == (action_dim,)
|
||||
|
||||
# Simulate receiving weights from learner
|
||||
|
||||
Reference in New Issue
Block a user