feat: add offline training in learner

This commit is contained in:
Khalil Meftah
2026-03-22 23:00:07 +01:00
parent d9371b9a34
commit 519234a5d8
+33
View File
@@ -373,6 +373,39 @@ def add_actor_information_and_train(
if cfg.dataset is not None:
dataset_repo_id = cfg.dataset.repo_id
# ── Offline phase (e.g. RLT RL-token training, ConRFT Cal-QL pretraining) ──
offline_steps = getattr(cfg.policy, "offline_steps", 0)
if algorithm.supports_offline_phase() and offline_steps > 0 and offline_replay_buffer is not None:
logging.info(f"[LEARNER] Starting offline phase ({offline_steps} steps)")
offline_mixer = OnlineOfflineMixer(
online_buffer=offline_replay_buffer,
offline_buffer=None,
online_ratio=1.0,
)
offline_iterator = algorithm.configure_data_iterator(
data_mixer=offline_mixer,
batch_size=total_batch_size,
async_prefetch=async_prefetch,
queue_size=queue_size,
)
for step in range(offline_steps):
if shutdown_event is not None and shutdown_event.is_set():
logging.info("[LEARNER] Shutdown during offline phase. Exiting...")
return
stats = algorithm.offline_update(offline_iterator)
if step % log_freq == 0:
logging.info(f"[LEARNER] Offline step {step}/{offline_steps}: {stats.to_log_dict()}")
if wandb_logger:
log_dict = stats.to_log_dict()
log_dict["offline_step"] = step
wandb_logger.log_dict(d=log_dict, mode="train", custom_step_key="offline_step")
algorithm.transition_to_online()
optimizers = algorithm.get_optimizers()
logging.info("[LEARNER] Offline phase complete, transitioned to online")
# NOTE: THIS IS THE MAIN LOOP OF THE LEARNER
while True:
# Exit the training loop if shutdown is requested