feat(rl): consolidate HIL-SERL checkpoint into HF-style components

Make `RLAlgorithmConfig` and `RLAlgorithm` `HubMixin`s, add abstract
`state_dict()` / `load_state_dict()` for critic ensemble, target nets
and `log_alpha`, and persist them as a sibling `algorithm/` component
next to `pretrained_model/`. Replace the pickled `training_state.pt`
with an enriched `training_step.json` carrying `step` and
`interaction_step`, so resume restores actor + critics + target nets +
temperature + optimizers + RNG + counters from HF-standard files.
This commit is contained in:
Khalil Meftah
2026-05-08 21:24:23 +02:00
parent b1b2708e2f
commit 23811b720d
8 changed files with 382 additions and 24 deletions
+6
View File
@@ -68,6 +68,12 @@ class _DummyRLAlgorithm(RLAlgorithm):
def load_weights(self, weights, device="cpu") -> None:
_ = (weights, device)
def state_dict(self) -> dict[str, torch.Tensor]:
return {}
def load_state_dict(self, state_dict, device="cpu") -> None:
_ = (state_dict, device)
class _SimpleMixer:
def get_iterator(self, batch_size: int, async_prefetch: bool = True, queue_size: int = 2):