diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 9987556f1..1863b0db6 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -1013,6 +1013,28 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): preprocessor.push_to_hub(active_cfg.repo_id) postprocessor.push_to_hub(active_cfg.repo_id) + # When EMA is on we *eval* the EMA weights but the push above + # ships the live weights — they're different models. Push the EMA + # weights too, to a sibling ``-ema`` repo, so both are + # fully loadable and you can benchmark/deploy whichever is better. + # Non-fatal: the live model is already up if this fails. + if ema is not None and not ( + not cfg.is_reward_model_training and cfg.policy.use_peft + ): + ema_model = ema.ema_model + ema_repo_id = f"{active_cfg.repo_id}-ema" + orig_repo_id = ema_model.config.repo_id + try: + ema_model.config.repo_id = ema_repo_id + ema_model.push_model_to_hub(cfg) + preprocessor.push_to_hub(ema_repo_id) + postprocessor.push_to_hub(ema_repo_id) + logging.info("Pushed EMA weights to %s", ema_repo_id) + except Exception as exc: # noqa: BLE001 + logging.warning("Failed to push EMA weights to %s: %s", ema_repo_id, exc) + finally: + ema_model.config.repo_id = orig_repo_id + # Properly clean up the distributed process group accelerator.wait_for_everyone() accelerator.end_training()