mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 17:20:05 +00:00
refactor: decouple policy from algorithm
This commit is contained in:
@@ -158,54 +158,55 @@ def test_training_stats_defaults():
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
def test_get_weights_returns_actor_state_dict():
|
||||
def test_get_weights_returns_policy_state_dict():
|
||||
algorithm, policy = _make_algorithm()
|
||||
weights = algorithm.get_weights()
|
||||
assert "policy" in weights
|
||||
for key in policy.actor.state_dict():
|
||||
assert key in weights["policy"]
|
||||
assert torch.equal(weights["policy"][key].cpu(), policy.actor.state_dict()[key].cpu())
|
||||
for key in policy.state_dict():
|
||||
assert key in weights
|
||||
assert torch.equal(weights[key].cpu(), policy.state_dict()[key].cpu())
|
||||
|
||||
|
||||
def test_get_weights_includes_discrete_critic_when_present():
|
||||
algorithm, policy = _make_algorithm(num_discrete_actions=3, action_dim=6)
|
||||
weights = algorithm.get_weights()
|
||||
assert "discrete_critic" in weights
|
||||
for key in policy.discrete_critic.state_dict():
|
||||
assert key in weights["discrete_critic"]
|
||||
dc_keys = [k for k in weights if k.startswith("discrete_critic.")]
|
||||
assert len(dc_keys) > 0
|
||||
|
||||
|
||||
def test_get_weights_excludes_discrete_critic_when_absent():
|
||||
algorithm, _ = _make_algorithm()
|
||||
weights = algorithm.get_weights()
|
||||
assert "discrete_critic" not in weights
|
||||
dc_keys = [k for k in weights if k.startswith("discrete_critic.")]
|
||||
assert len(dc_keys) == 0
|
||||
|
||||
|
||||
def test_get_weights_are_on_cpu():
|
||||
algorithm, _ = _make_algorithm()
|
||||
weights = algorithm.get_weights()
|
||||
for key, tensor in weights["policy"].items():
|
||||
for key, tensor in weights.items():
|
||||
assert tensor.device == torch.device("cpu"), f"{key} is not on CPU"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# select_action
|
||||
# select_action (lives on the policy, not the algorithm)
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
def test_select_action_returns_correct_shape():
|
||||
action_dim = 6
|
||||
algorithm, _ = _make_algorithm(state_dim=10, action_dim=action_dim)
|
||||
_, policy = _make_algorithm(state_dim=10, action_dim=action_dim)
|
||||
policy.eval()
|
||||
obs = {OBS_STATE: torch.randn(10)}
|
||||
action = algorithm.select_action(obs)
|
||||
action = policy.select_action(obs)
|
||||
assert action.shape == (action_dim,)
|
||||
|
||||
|
||||
def test_select_action_with_discrete_critic():
|
||||
continuous_dim = 5
|
||||
algorithm, _ = _make_algorithm(state_dim=10, action_dim=continuous_dim, num_discrete_actions=3)
|
||||
_, policy = _make_algorithm(state_dim=10, action_dim=continuous_dim, num_discrete_actions=3)
|
||||
policy.eval()
|
||||
obs = {OBS_STATE: torch.randn(10)}
|
||||
action = algorithm.select_action(obs)
|
||||
action = policy.select_action(obs)
|
||||
assert action.shape == (continuous_dim + 1,)
|
||||
|
||||
|
||||
@@ -298,12 +299,12 @@ def test_update_utd_ratio_3_critic_warmup_changes_weights():
|
||||
"""With utd_ratio=3, critic weights should change after update (3 critic steps)."""
|
||||
algorithm, policy = _make_algorithm(utd_ratio=3)
|
||||
|
||||
critic_params_before = {n: p.clone() for n, p in policy.critic_ensemble.named_parameters()}
|
||||
critic_params_before = {n: p.clone() for n, p in algorithm.critic_ensemble.named_parameters()}
|
||||
|
||||
algorithm.update(_batch_iterator())
|
||||
|
||||
changed = False
|
||||
for n, p in policy.critic_ensemble.named_parameters():
|
||||
for n, p in algorithm.critic_ensemble.named_parameters():
|
||||
if not torch.equal(p, critic_params_before[n]):
|
||||
changed = True
|
||||
break
|
||||
@@ -403,11 +404,11 @@ def test_load_weights_round_trip():
|
||||
weights = algo_src.get_weights()
|
||||
algo_dst.load_weights(weights, device="cpu")
|
||||
|
||||
for key in weights["policy"]:
|
||||
for key in weights:
|
||||
assert torch.equal(
|
||||
algo_dst.policy.actor.state_dict()[key].cpu(),
|
||||
weights["policy"][key].cpu(),
|
||||
), f"Actor param '{key}' mismatch after load_weights"
|
||||
algo_dst.policy.state_dict()[key].cpu(),
|
||||
weights[key].cpu(),
|
||||
), f"Policy param '{key}' mismatch after load_weights"
|
||||
|
||||
|
||||
def test_load_weights_round_trip_with_discrete_critic():
|
||||
@@ -421,17 +422,19 @@ def test_load_weights_round_trip_with_discrete_critic():
|
||||
weights = algo_src.get_weights()
|
||||
algo_dst.load_weights(weights, device="cpu")
|
||||
|
||||
for key in weights["discrete_critic"]:
|
||||
dc_keys = [k for k in weights if k.startswith("discrete_critic.")]
|
||||
assert len(dc_keys) > 0
|
||||
for key in dc_keys:
|
||||
assert torch.equal(
|
||||
algo_dst.policy.discrete_critic.state_dict()[key].cpu(),
|
||||
weights["discrete_critic"][key].cpu(),
|
||||
algo_dst.policy.state_dict()[key].cpu(),
|
||||
weights[key].cpu(),
|
||||
), f"Discrete critic param '{key}' mismatch after load_weights"
|
||||
|
||||
|
||||
def test_load_weights_ignores_missing_discrete_critic():
|
||||
"""load_weights should not fail when weights lack discrete_critic on a non-discrete policy."""
|
||||
algorithm, _ = _make_algorithm()
|
||||
weights = {"policy": algorithm.get_weights()["policy"]}
|
||||
weights = algorithm.get_weights()
|
||||
algorithm.load_weights(weights, device="cpu")
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user