Add offline phase hooks to RLAlgorithm base

This commit is contained in:
Khalil Meftah
2026-03-22 22:52:56 +01:00
parent f495054321
commit 05395c8b10
+29
View File
@@ -91,6 +91,35 @@ class RLAlgorithm(abc.ABC):
"""
...
def supports_offline_phase(self) -> bool:
"""Whether this algorithm has an offline pretraining phase.
Algorithms like RLT (RL-token training) or ConRFT (Cal-QL pretraining)
return ``True`` here. The learner checks this before the main online
loop and routes to :meth:`offline_update` accordingly.
"""
return False
def offline_update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
"""One offline training step (called before any online collection).
Only called when :meth:`supports_offline_phase` returns ``True``.
Uses the same iterator protocol as :meth:`update`.
"""
raise NotImplementedError(
f"{type(self).__name__} does not implement offline_update(). "
"Either override this method or return False from supports_offline_phase()."
)
def transition_to_online(self) -> None: # noqa: B027
"""Called once when switching from offline to online phase.
Use this to freeze modules trained offline, rebuild optimizers for the
online phase, reset step counters, etc.
Default is a no-op; subclasses override when they have an offline phase.
"""
def configure_data_iterator(
self,
data_mixer: DataMixer,