feat(rl): consolidate HIL-SERL checkpoint into HF-style components
Make save_pretrained and from_pretrained work on SACAlgorithm, add abstract state_dict / load_state_dict for algorithm-owned tensors (critics, target nets, log_alpha), and persist them as a sibling component next to the policy. Replace the pickled side-file with an enriched training state carrying both policy and algorithm state, so resume restores actor + critics + target nets + temperature + optimizers + RNG + counters from plain HF-standard files.
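The tests below pin down a narrow state_dict contract: flat, prefixed keys for everything the algorithm owns, and nothing the policy already serializes. A minimal sketch of such a packing, assuming critic modules exposed as critic_ensemble / critic_target to match the key prefixes asserted in the tests (an illustration of the contract, not lerobot's actual implementation):

import torch
from torch import nn


def pack_algorithm_state(
    critic_ensemble: nn.Module,
    critic_target: nn.Module,
    log_alpha: torch.Tensor,
) -> dict[str, torch.Tensor]:
    """Collect only algorithm-owned tensors under flat, prefixed keys."""
    state = {f"critic_ensemble.{k}": v for k, v in critic_ensemble.state_dict().items()}
    state.update({f"critic_target.{k}": v for k, v in critic_target.state_dict().items()})
    state["log_alpha"] = log_alpha.detach()
    # Encoder weights stay on the policy checkpoint, so the two sibling
    # components never duplicate bytes on disk.
    return state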
@@ -515,3 +515,88 @@ def test_make_algorithm_builds_sac():
     algorithm = make_algorithm(cfg=algo_config, policy=policy)
     assert isinstance(algorithm, SACAlgorithm)
     assert algorithm.config.utd_ratio == 2
+
+
+# ===========================================================================
+# state_dict / load_state_dict (algorithm-side resume)
+# ===========================================================================
+
+
+def test_state_dict_contains_algorithm_owned_tensors():
+    """state_dict should pack critics, target networks, and log_alpha (no encoder bloat)."""
+    algorithm, _ = _make_algorithm()
+    sd = algorithm.state_dict()
+
+    assert "log_alpha" in sd
+    assert any(k.startswith("critic_ensemble.") for k in sd)
+    assert any(k.startswith("critic_target.") for k in sd)
+    # Encoder weights live on the policy and must not be duplicated here.
+    assert not any(".encoder." in k for k in sd)
+
+
+def test_state_dict_includes_discrete_critic_target_when_present():
+    algorithm, _ = _make_algorithm(num_discrete_actions=3, action_dim=6)
+    sd = algorithm.state_dict()
+    assert any(k.startswith("discrete_critic_target.") for k in sd)
+
+
+def test_load_state_dict_round_trip_restores_critics_and_log_alpha():
+    """state_dict -> load_state_dict on a fresh algorithm restores all bytes exactly."""
+    sac_cfg = _make_sac_config(num_discrete_actions=3, action_dim=6)
+    src_policy = GaussianActorPolicy(config=sac_cfg)
+    src = SACAlgorithm(policy=src_policy, config=SACAlgorithmConfig.from_policy_config(sac_cfg))
+    src.make_optimizers_and_scheduler()
+    # Train a few steps so weights diverge from init (action_dim=7 = 6 continuous + 1 discrete).
+    src.update(_batch_iterator(action_dim=7))
+    src.update(_batch_iterator(action_dim=7))
+
+    dst_policy = GaussianActorPolicy(config=sac_cfg)
+    dst = SACAlgorithm(policy=dst_policy, config=SACAlgorithmConfig.from_policy_config(sac_cfg))
+    dst.make_optimizers_and_scheduler()
+
+    src_sd = src.state_dict()
+    dst.load_state_dict(src_sd)
+    dst_sd = dst.state_dict()
+
+    assert set(dst_sd) == set(src_sd)
+    for key in src_sd:
+        assert torch.allclose(src_sd[key].cpu(), dst_sd[key].cpu()), f"{key} mismatch after round-trip"
+
+
+def test_load_state_dict_preserves_log_alpha_parameter_identity():
+    """The temperature optimizer holds a reference to log_alpha; identity must survive load."""
+    algorithm, _ = _make_algorithm()
+    log_alpha_id_before = id(algorithm.log_alpha)
+    optimizer_param_id = id(algorithm.optimizers["temperature"].param_groups[0]["params"][0])
+    assert log_alpha_id_before == optimizer_param_id
+
+    new_state = algorithm.state_dict()
+    new_state["log_alpha"] = torch.tensor([0.42])
+    algorithm.load_state_dict(new_state)
+
+    assert id(algorithm.log_alpha) == log_alpha_id_before
+    assert id(algorithm.optimizers["temperature"].param_groups[0]["params"][0]) == log_alpha_id_before
+    assert torch.allclose(algorithm.log_alpha.detach().cpu(), torch.tensor([0.42]))
+
+
+def test_save_pretrained_round_trip_via_disk(tmp_path):
+    """End-to-end: save_pretrained -> from_pretrained restores tensors and config."""
+    sac_cfg = _make_sac_config()
+    src_policy = GaussianActorPolicy(config=sac_cfg)
+    src = SACAlgorithm(policy=src_policy, config=SACAlgorithmConfig.from_policy_config(sac_cfg))
+    src.make_optimizers_and_scheduler()
+    src.update(_batch_iterator())
+
+    save_dir = tmp_path / "algorithm"
+    src.save_pretrained(save_dir)
+    assert (save_dir / "model.safetensors").is_file()
+    assert (save_dir / "config.json").is_file()
+
+    dst_policy = GaussianActorPolicy(config=sac_cfg)
+    dst = SACAlgorithm.from_pretrained(save_dir, policy=dst_policy)
+
+    src_sd = src.state_dict()
+    dst_sd = dst.state_dict()
+    assert set(src_sd) == set(dst_sd)
+    for key in src_sd:
+        assert torch.allclose(src_sd[key].cpu(), dst_sd[key].cpu()), f"{key} mismatch after disk round-trip"
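The identity test above implies that load_state_dict writes log_alpha in place rather than rebinding the attribute; otherwise the temperature optimizer would keep stepping a dead tensor. A minimal sketch of that pattern using plain torch primitives (an assumption about the mechanism, not the committed code):

import torch

# An in-place copy keeps the Parameter object alive, so an optimizer that
# captured it at construction time keeps updating the same tensor.
log_alpha = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([log_alpha], lr=3e-4)

with torch.no_grad():
    log_alpha.copy_(torch.tensor([0.42]))  # id(log_alpha) is unchanged

assert optimizer.param_groups[0]["params"][0] is log_alpha
assert torch.allclose(log_alpha.detach(), torch.tensor([0.42]))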