mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 11:39:50 +00:00
fix test for flat dict structure
This commit is contained in:
@@ -165,30 +165,32 @@ def test_training_stats_defaults():
|
|||||||
def test_get_weights_returns_policy_state_dict():
|
def test_get_weights_returns_policy_state_dict():
|
||||||
algorithm, policy = _make_algorithm()
|
algorithm, policy = _make_algorithm()
|
||||||
weights = algorithm.get_weights()
|
weights = algorithm.get_weights()
|
||||||
for key in policy.state_dict():
|
assert "policy" in weights
|
||||||
assert key in weights
|
actor_state_dict = policy.actor.state_dict()
|
||||||
assert torch.equal(weights[key].cpu(), policy.state_dict()[key].cpu())
|
for key in actor_state_dict:
|
||||||
|
assert key in weights["policy"]
|
||||||
|
assert torch.equal(weights["policy"][key].cpu(), actor_state_dict[key].cpu())
|
||||||
|
|
||||||
|
|
||||||
def test_get_weights_includes_discrete_critic_when_present():
|
def test_get_weights_includes_discrete_critic_when_present():
|
||||||
algorithm, policy = _make_algorithm(num_discrete_actions=3, action_dim=6)
|
algorithm, _ = _make_algorithm(num_discrete_actions=3, action_dim=6)
|
||||||
weights = algorithm.get_weights()
|
weights = algorithm.get_weights()
|
||||||
dc_keys = [k for k in weights if k.startswith("discrete_critic.")]
|
assert "discrete_critic" in weights
|
||||||
assert len(dc_keys) > 0
|
assert len(weights["discrete_critic"]) > 0
|
||||||
|
|
||||||
|
|
||||||
def test_get_weights_excludes_discrete_critic_when_absent():
|
def test_get_weights_excludes_discrete_critic_when_absent():
|
||||||
algorithm, _ = _make_algorithm()
|
algorithm, _ = _make_algorithm()
|
||||||
weights = algorithm.get_weights()
|
weights = algorithm.get_weights()
|
||||||
dc_keys = [k for k in weights if k.startswith("discrete_critic.")]
|
assert "discrete_critic" not in weights
|
||||||
assert len(dc_keys) == 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_weights_are_on_cpu():
|
def test_get_weights_are_on_cpu():
|
||||||
algorithm, _ = _make_algorithm()
|
algorithm, _ = _make_algorithm(num_discrete_actions=3, action_dim=6)
|
||||||
weights = algorithm.get_weights()
|
weights = algorithm.get_weights()
|
||||||
for key, tensor in weights.items():
|
for group_name, state_dict in weights.items():
|
||||||
assert tensor.device == torch.device("cpu"), f"{key} is not on CPU"
|
for key, tensor in state_dict.items():
|
||||||
|
assert tensor.device == torch.device("cpu"), f"{group_name}/{key} is not on CPU"
|
||||||
|
|
||||||
|
|
||||||
# ===========================================================================
|
# ===========================================================================
|
||||||
@@ -408,10 +410,11 @@ def test_load_weights_round_trip():
|
|||||||
weights = algo_src.get_weights()
|
weights = algo_src.get_weights()
|
||||||
algo_dst.load_weights(weights, device="cpu")
|
algo_dst.load_weights(weights, device="cpu")
|
||||||
|
|
||||||
for key in weights:
|
dst_actor_state_dict = algo_dst.policy.actor.state_dict()
|
||||||
|
for key, tensor in weights["policy"].items():
|
||||||
assert torch.equal(
|
assert torch.equal(
|
||||||
algo_dst.policy.state_dict()[key].cpu(),
|
dst_actor_state_dict[key].cpu(),
|
||||||
weights[key].cpu(),
|
tensor.cpu(),
|
||||||
), f"Policy param '{key}' mismatch after load_weights"
|
), f"Policy param '{key}' mismatch after load_weights"
|
||||||
|
|
||||||
|
|
||||||
@@ -426,12 +429,13 @@ def test_load_weights_round_trip_with_discrete_critic():
|
|||||||
weights = algo_src.get_weights()
|
weights = algo_src.get_weights()
|
||||||
algo_dst.load_weights(weights, device="cpu")
|
algo_dst.load_weights(weights, device="cpu")
|
||||||
|
|
||||||
dc_keys = [k for k in weights if k.startswith("discrete_critic.")]
|
assert "discrete_critic" in weights
|
||||||
assert len(dc_keys) > 0
|
assert len(weights["discrete_critic"]) > 0
|
||||||
for key in dc_keys:
|
dst_dc_state_dict = algo_dst.discrete_critic.state_dict()
|
||||||
|
for key, tensor in weights["discrete_critic"].items():
|
||||||
assert torch.equal(
|
assert torch.equal(
|
||||||
algo_dst.policy.state_dict()[key].cpu(),
|
dst_dc_state_dict[key].cpu(),
|
||||||
weights[key].cpu(),
|
tensor.cpu(),
|
||||||
), f"Discrete critic param '{key}' mismatch after load_weights"
|
), f"Discrete critic param '{key}' mismatch after load_weights"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user