feat(rl): consolidate HIL-SERL checkpoint into HF-style components

Make the policy and the algorithm independently serializable HF-style
components, add abstract `state_dict` / `load_state_dict` for
algorithm-owned tensors (critics, target nets, `log_alpha`), and persist
them via `save_pretrained` / `from_pretrained` as a sibling component
next to the policy checkpoint. Replace the pickled training-state
side-file with an enriched HF-style checkpoint carrying both model state
and training state, so resume restores actor + critics + target nets +
temperature + optimizers + RNG + counters from plain HF-standard files.
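Illustrative sketch of the resulting persistence surface. `TinyAlgorithm`
and its internals are hypothetical stand-ins, not the real `SACAlgorithm`;
only the `state_dict` / `load_state_dict` / `save_pretrained` /
`from_pretrained` surface and the `model.safetensors` + `config.json`
layout mirror what the tests below exercise:

    # Hypothetical sketch of the HF-style layout: algorithm-owned tensors
    # go to model.safetensors, the config to config.json.
    import json
    from pathlib import Path

    import torch
    from safetensors.torch import load_file, save_file


    class TinyAlgorithm:
        def __init__(self, obs_dim: int = 4, action_dim: int = 2):
            self.config = {"obs_dim": obs_dim, "action_dim": action_dim}
            self.critic_ensemble = torch.nn.Linear(obs_dim + action_dim, 1)
            self.critic_target = torch.nn.Linear(obs_dim + action_dim, 1)
            self.log_alpha = torch.nn.Parameter(torch.zeros(1))

        def state_dict(self) -> dict[str, torch.Tensor]:
            # Pack only algorithm-owned tensors; encoder weights stay on the policy.
            sd = {"log_alpha": self.log_alpha.detach().cpu()}
            for prefix, module in (("critic_ensemble", self.critic_ensemble),
                                   ("critic_target", self.critic_target)):
                for key, value in module.state_dict().items():
                    sd[f"{prefix}.{key}"] = value.cpu()
            return sd

        def load_state_dict(self, sd: dict[str, torch.Tensor]) -> None:
            # Copy log_alpha in place so optimizers holding a reference keep it valid.
            with torch.no_grad():
                self.log_alpha.copy_(sd["log_alpha"])
            for prefix, module in (("critic_ensemble", self.critic_ensemble),
                                   ("critic_target", self.critic_target)):
                module.load_state_dict(
                    {k.split(".", 1)[1]: v for k, v in sd.items() if k.startswith(prefix + ".")})

        def save_pretrained(self, save_dir) -> None:
            save_dir = Path(save_dir)
            save_dir.mkdir(parents=True, exist_ok=True)
            save_file(self.state_dict(), str(save_dir / "model.safetensors"))
            (save_dir / "config.json").write_text(json.dumps(self.config))

        @classmethod
        def from_pretrained(cls, save_dir) -> "TinyAlgorithm":
            save_dir = Path(save_dir)
            algo = cls(**json.loads((save_dir / "config.json").read_text()))
            algo.load_state_dict(load_file(str(save_dir / "model.safetensors")))
            return algo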
Khalil Meftah
2026-05-08 21:24:23 +02:00
parent b1b2708e2f
commit 0944b84279
8 changed files with 382 additions and 24 deletions
@@ -515,3 +515,88 @@ def test_make_algorithm_builds_sac():
    algorithm = make_algorithm(cfg=algo_config, policy=policy)
    assert isinstance(algorithm, SACAlgorithm)
    assert algorithm.config.utd_ratio == 2


# ===========================================================================
# state_dict / load_state_dict (algorithm-side resume)
# ===========================================================================


def test_state_dict_contains_algorithm_owned_tensors():
    """state_dict should pack critics, target networks, and log_alpha (no encoder bloat)."""
    algorithm, _ = _make_algorithm()
    sd = algorithm.state_dict()
    assert "log_alpha" in sd
    assert any(k.startswith("critic_ensemble.") for k in sd)
    assert any(k.startswith("critic_target.") for k in sd)
    # Encoder weights live on the policy and must not be duplicated here.
    assert not any(".encoder." in k for k in sd)
def test_state_dict_includes_discrete_critic_target_when_present():
    """With discrete actions configured, the discrete critic's target net is algorithm-owned too."""
    algorithm, _ = _make_algorithm(num_discrete_actions=3, action_dim=6)
    sd = algorithm.state_dict()
    assert any(k.startswith("discrete_critic_target.") for k in sd)
def test_load_state_dict_round_trip_restores_critics_and_log_alpha():
    """state_dict -> load_state_dict on a fresh algorithm restores all bytes exactly."""
    sac_cfg = _make_sac_config(num_discrete_actions=3, action_dim=6)
    src_policy = GaussianActorPolicy(config=sac_cfg)
    src = SACAlgorithm(policy=src_policy, config=SACAlgorithmConfig.from_policy_config(sac_cfg))
    src.make_optimizers_and_scheduler()
    # Train a few steps so weights diverge from init (action_dim=7 = 6 continuous + 1 discrete).
    src.update(_batch_iterator(action_dim=7))
    src.update(_batch_iterator(action_dim=7))

    dst_policy = GaussianActorPolicy(config=sac_cfg)
    dst = SACAlgorithm(policy=dst_policy, config=SACAlgorithmConfig.from_policy_config(sac_cfg))
    dst.make_optimizers_and_scheduler()

    src_sd = src.state_dict()
    dst.load_state_dict(src_sd)
    dst_sd = dst.state_dict()

    assert set(dst_sd) == set(src_sd)
    for key in src_sd:
        assert torch.allclose(src_sd[key].cpu(), dst_sd[key].cpu()), f"{key} mismatch after round-trip"
def test_load_state_dict_preserves_log_alpha_parameter_identity():
    """The temperature optimizer holds a reference to log_alpha; identity must survive load."""
    algorithm, _ = _make_algorithm()
    log_alpha_id_before = id(algorithm.log_alpha)
    optimizer_param_id = id(algorithm.optimizers["temperature"].param_groups[0]["params"][0])
    assert log_alpha_id_before == optimizer_param_id

    new_state = algorithm.state_dict()
    new_state["log_alpha"] = torch.tensor([0.42])
    algorithm.load_state_dict(new_state)

    assert id(algorithm.log_alpha) == log_alpha_id_before
    assert id(algorithm.optimizers["temperature"].param_groups[0]["params"][0]) == log_alpha_id_before
    assert torch.allclose(algorithm.log_alpha.detach().cpu(), torch.tensor([0.42]))
def test_save_pretrained_round_trip_via_disk(tmp_path):
    """End-to-end: save_pretrained -> from_pretrained restores tensors and config."""
    sac_cfg = _make_sac_config()
    src_policy = GaussianActorPolicy(config=sac_cfg)
    src = SACAlgorithm(policy=src_policy, config=SACAlgorithmConfig.from_policy_config(sac_cfg))
    src.make_optimizers_and_scheduler()
    src.update(_batch_iterator())

    save_dir = tmp_path / "algorithm"
    src.save_pretrained(save_dir)
    assert (save_dir / "model.safetensors").is_file()
    assert (save_dir / "config.json").is_file()

    dst_policy = GaussianActorPolicy(config=sac_cfg)
    dst = SACAlgorithm.from_pretrained(save_dir, policy=dst_policy)

    src_sd = src.state_dict()
    dst_sd = dst.state_dict()
    assert set(src_sd) == set(dst_sd)
    for key in src_sd:
        assert torch.allclose(src_sd[key].cpu(), dst_sd[key].cpu()), f"{key} mismatch after disk round-trip"
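Note on the identity test above: it pins down the load strategy. Incoming
tensors must be copied into the existing `log_alpha` Parameter rather than
rebinding the attribute, or the temperature optimizer's reference would go
stale. A minimal standalone demonstration of the invariant (not the
SACAlgorithm code):

    import torch

    log_alpha = torch.nn.Parameter(torch.zeros(1))
    opt = torch.optim.Adam([log_alpha], lr=1e-3)

    with torch.no_grad():
        log_alpha.copy_(torch.tensor([0.42]))  # in place: same Parameter object survives
    # log_alpha = torch.nn.Parameter(torch.tensor([0.42]))  # would break the optimizer link

    assert log_alpha is opt.param_groups[0]["params"][0]
    assert torch.allclose(log_alpha.detach(), torch.tensor([0.42]))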