diff --git a/src/lerobot/rl/learner.py b/src/lerobot/rl/learner.py
index 3467b4558..a2e0b9996 100644
--- a/src/lerobot/rl/learner.py
+++ b/src/lerobot/rl/learner.py
@@ -373,6 +373,39 @@ def add_actor_information_and_train(
     if cfg.dataset is not None:
         dataset_repo_id = cfg.dataset.repo_id
 
+    # ── Offline phase (e.g. RLT RL-token training, ConRFT Cal-QL pretraining) ──
+    offline_steps = getattr(cfg.policy, "offline_steps", 0)
+    if algorithm.supports_offline_phase() and offline_steps > 0 and offline_replay_buffer is not None:
+        logging.info(f"[LEARNER] Starting offline phase ({offline_steps} steps)")
+        offline_mixer = OnlineOfflineMixer(
+            online_buffer=offline_replay_buffer,
+            offline_buffer=None,
+            online_ratio=1.0,
+        )
+        offline_iterator = algorithm.configure_data_iterator(
+            data_mixer=offline_mixer,
+            batch_size=total_batch_size,
+            async_prefetch=async_prefetch,
+            queue_size=queue_size,
+        )
+        for step in range(offline_steps):
+            if shutdown_event is not None and shutdown_event.is_set():
+                logging.info("[LEARNER] Shutdown during offline phase. Exiting...")
+                return
+
+            stats = algorithm.offline_update(offline_iterator)
+
+            if step % log_freq == 0:
+                logging.info(f"[LEARNER] Offline step {step}/{offline_steps}: {stats.to_log_dict()}")
+                if wandb_logger:
+                    log_dict = stats.to_log_dict()
+                    log_dict["offline_step"] = step
+                    wandb_logger.log_dict(d=log_dict, mode="train", custom_step_key="offline_step")
+
+        algorithm.transition_to_online()
+        optimizers = algorithm.get_optimizers()
+        logging.info("[LEARNER] Offline phase complete, transitioned to online")
+
     # NOTE: THIS IS THE MAIN LOOP OF THE LEARNER
     while True:
         # Exit the training loop if shutdown is requested