mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-14 16:19:45 +00:00
fix test for flat dict structure
This commit is contained in:
@@ -165,30 +165,32 @@ def test_training_stats_defaults():
|
||||
def test_get_weights_returns_policy_state_dict():
|
||||
algorithm, policy = _make_algorithm()
|
||||
weights = algorithm.get_weights()
|
||||
for key in policy.state_dict():
|
||||
assert key in weights
|
||||
assert torch.equal(weights[key].cpu(), policy.state_dict()[key].cpu())
|
||||
assert "policy" in weights
|
||||
actor_state_dict = policy.actor.state_dict()
|
||||
for key in actor_state_dict:
|
||||
assert key in weights["policy"]
|
||||
assert torch.equal(weights["policy"][key].cpu(), actor_state_dict[key].cpu())
|
||||
|
||||
|
||||
def test_get_weights_includes_discrete_critic_when_present():
|
||||
algorithm, policy = _make_algorithm(num_discrete_actions=3, action_dim=6)
|
||||
algorithm, _ = _make_algorithm(num_discrete_actions=3, action_dim=6)
|
||||
weights = algorithm.get_weights()
|
||||
dc_keys = [k for k in weights if k.startswith("discrete_critic.")]
|
||||
assert len(dc_keys) > 0
|
||||
assert "discrete_critic" in weights
|
||||
assert len(weights["discrete_critic"]) > 0
|
||||
|
||||
|
||||
def test_get_weights_excludes_discrete_critic_when_absent():
|
||||
algorithm, _ = _make_algorithm()
|
||||
weights = algorithm.get_weights()
|
||||
dc_keys = [k for k in weights if k.startswith("discrete_critic.")]
|
||||
assert len(dc_keys) == 0
|
||||
assert "discrete_critic" not in weights
|
||||
|
||||
|
||||
def test_get_weights_are_on_cpu():
|
||||
algorithm, _ = _make_algorithm()
|
||||
algorithm, _ = _make_algorithm(num_discrete_actions=3, action_dim=6)
|
||||
weights = algorithm.get_weights()
|
||||
for key, tensor in weights.items():
|
||||
assert tensor.device == torch.device("cpu"), f"{key} is not on CPU"
|
||||
for group_name, state_dict in weights.items():
|
||||
for key, tensor in state_dict.items():
|
||||
assert tensor.device == torch.device("cpu"), f"{group_name}/{key} is not on CPU"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
@@ -408,10 +410,11 @@ def test_load_weights_round_trip():
|
||||
weights = algo_src.get_weights()
|
||||
algo_dst.load_weights(weights, device="cpu")
|
||||
|
||||
for key in weights:
|
||||
dst_actor_state_dict = algo_dst.policy.actor.state_dict()
|
||||
for key, tensor in weights["policy"].items():
|
||||
assert torch.equal(
|
||||
algo_dst.policy.state_dict()[key].cpu(),
|
||||
weights[key].cpu(),
|
||||
dst_actor_state_dict[key].cpu(),
|
||||
tensor.cpu(),
|
||||
), f"Policy param '{key}' mismatch after load_weights"
|
||||
|
||||
|
||||
@@ -426,12 +429,13 @@ def test_load_weights_round_trip_with_discrete_critic():
|
||||
weights = algo_src.get_weights()
|
||||
algo_dst.load_weights(weights, device="cpu")
|
||||
|
||||
dc_keys = [k for k in weights if k.startswith("discrete_critic.")]
|
||||
assert len(dc_keys) > 0
|
||||
for key in dc_keys:
|
||||
assert "discrete_critic" in weights
|
||||
assert len(weights["discrete_critic"]) > 0
|
||||
dst_dc_state_dict = algo_dst.discrete_critic.state_dict()
|
||||
for key, tensor in weights["discrete_critic"].items():
|
||||
assert torch.equal(
|
||||
algo_dst.policy.state_dict()[key].cpu(),
|
||||
weights[key].cpu(),
|
||||
dst_dc_state_dict[key].cpu(),
|
||||
tensor.cpu(),
|
||||
), f"Discrete critic param '{key}' mismatch after load_weights"
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user