feat(train): also push EMA weights to <repo_id>-ema

When EMA is enabled we evaluate the EMA weights, but only the live weights
are pushed to the hub, so the model we benchmark offline differs from the
one selected during training. Also push the EMA weights to a sibling
<repo_id>-ema repo so both are fully loadable and the better one can be
picked. A failed EMA push is non-fatal: it is logged as a warning and the
training shutdown proceeds.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
pepijn223
2026-06-24 14:53:44 +02:00
parent ab0147f1ca
commit e1dc741709
+22
View File
@@ -1013,6 +1013,28 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
preprocessor.push_to_hub(active_cfg.repo_id)
postprocessor.push_to_hub(active_cfg.repo_id)
# When EMA is on we *eval* the EMA weights but the push above
# ships the live weights — they're different models. Push the EMA
# weights too, to a sibling ``<repo_id>-ema`` repo, so both are
# fully loadable and you can benchmark/deploy whichever is better.
# Non-fatal: the live model is already up if this fails.
if ema is not None and not (
not cfg.is_reward_model_training and cfg.policy.use_peft
):
ema_model = ema.ema_model
ema_repo_id = f"{active_cfg.repo_id}-ema"
orig_repo_id = ema_model.config.repo_id
try:
ema_model.config.repo_id = ema_repo_id
ema_model.push_model_to_hub(cfg)
preprocessor.push_to_hub(ema_repo_id)
postprocessor.push_to_hub(ema_repo_id)
logging.info("Pushed EMA weights to %s", ema_repo_id)
except Exception as exc: # noqa: BLE001
logging.warning("Failed to push EMA weights to %s: %s", ema_repo_id, exc)
finally:
ema_model.config.repo_id = orig_repo_id
# Properly clean up the distributed process group
accelerator.wait_for_everyone()
accelerator.end_training()