Mirror of https://github.com/huggingface/lerobot.git (synced 2026-05-16 09:09:48 +00:00)
refactor(rl): move actor weight-sync wire format from policy to algorithm
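For orientation: the test below indexes the exported payload as `weights["policy"]` and `weights["discrete_critic"]` and compares individual tensors, so the wire format is a plain dict of per-module state dicts. A minimal sketch of that payload with stand-in modules (the layer sizes and the detach/CPU step are assumptions, not confirmed by the diff):

```python
from torch import nn

# Illustrative stand-in modules; the layer sizes are arbitrary and not
# taken from the test fixtures.
actor = nn.Linear(8, 6)
discrete_critic = nn.Linear(8, 3)

# Payload shape implied by the assertions in the diff: one plain tensor
# state dict per module, keyed by module role. Detaching and moving to CPU
# before shipping is an assumption.
weights = {
    "policy": {k: v.detach().cpu() for k, v in actor.state_dict().items()},
    "discrete_critic": {
        k: v.detach().cpu() for k, v in discrete_critic.state_dict().items()
    },
}
```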
@@ -445,35 +445,39 @@ def test_load_weights_ignores_missing_discrete_critic():
 
 
 def test_actor_side_weight_sync_with_discrete_critic():
-    """End-to-end: learner ``algorithm.get_weights()`` -> actor ``policy.load_actor_weights()``."""
-    # Learner side: train the algorithm so its weights diverge from init.
+    """End-to-end: learner ``algorithm.get_weights()`` -> actor ``algorithm.load_weights()``."""
+    # Learner side: train the source algorithm so its weights diverge from init.
     algo_src, _ = _make_algorithm(num_discrete_actions=3, action_dim=6)
     algo_src.update(_batch_iterator(action_dim=7))
     weights = algo_src.get_weights()
 
-    # Actor side: fresh policy, no algorithm/optimizer.
+    # Actor side: fresh policy + fresh algorithm holding it.
     sac_cfg = _make_sac_config(num_discrete_actions=3, action_dim=6)
     policy_actor = GaussianActorPolicy(config=sac_cfg)
+    algo_actor = SACAlgorithm(
+        policy=policy_actor,
+        config=SACAlgorithmConfig.from_policy_config(sac_cfg),
+    )
 
     # Snapshot the initial discrete-critic state for the "did it change?" assertion below.
     initial_discrete_critic_state_dict = {
        k: v.clone() for k, v in policy_actor.discrete_critic.state_dict().items()
     }
 
-    policy_actor.load_actor_weights(weights, device="cpu")
+    algo_actor.load_weights(weights, device="cpu")
 
     # Actor weights match the learner's exported actor state dict.
     actor_state_dict = policy_actor.actor.state_dict()
     for key, tensor in weights["policy"].items():
         assert torch.equal(actor_state_dict[key].cpu(), tensor.cpu()), (
-            f"Actor param '{key}' not synced by load_actor_weights"
+            f"Actor param '{key}' not synced by algorithm.load_weights"
         )
 
     # Discrete critic weights match the learner's exported discrete critic.
     discrete_critic_state_dict = policy_actor.discrete_critic.state_dict()
     for key, tensor in weights["discrete_critic"].items():
         assert torch.equal(discrete_critic_state_dict[key].cpu(), tensor.cpu()), (
-            f"Discrete critic param '{key}' not synced by load_actor_weights"
+            f"Discrete critic param '{key}' not synced by algorithm.load_weights"
         )
 
     # Sanity: the discrete critic actually changed (otherwise the sync is trivial).
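The diff only exercises the happy path; the sibling test named in the hunk header, `test_load_weights_ignores_missing_discrete_critic`, implies the loader skips payload entries the actor does not carry. A hedged sketch of an actor-side loader with that behavior — `load_weights_sketch` and its `modules` mapping are hypothetical stand-ins, not the real `SACAlgorithm.load_weights` signature:

```python
import torch
from torch import nn


def load_weights_sketch(
    modules: dict[str, nn.Module],
    weights: dict[str, dict[str, torch.Tensor]],
    device: str = "cpu",
) -> None:
    """Copy each exported state dict into its module, then move the module
    to the target device; module roles absent from the payload are skipped."""
    for name, module in modules.items():
        state_dict = weights.get(name)
        if state_dict is None:
            continue  # e.g. a payload exported without a discrete critic
        module.load_state_dict({k: v.to(device) for k, v in state_dict.items()})
        module.to(device)


# Round trip mirroring the test: export from a "trained" stand-in, load into
# a freshly initialized one, and check per-parameter tensor equality.
src_actor, dst_actor = nn.Linear(8, 6), nn.Linear(8, 6)
payload = {"policy": {k: v.detach().cpu() for k, v in src_actor.state_dict().items()}}
load_weights_sketch({"policy": dst_actor}, payload, device="cpu")
for key, tensor in payload["policy"].items():
    assert torch.equal(dst_actor.state_dict()[key].cpu(), tensor.cpu())
```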