mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-27 21:27:21 +00:00
feat(train): also push EMA weights to <repo_id>-ema
When EMA is enabled, we evaluate the EMA weights, but only the live weights were pushed to the Hub, so the model we benchmark offline differs from the one selected during training. Push the EMA weights to a sibling <repo_id>-ema repo too (non-fatal: a failed EMA push is logged as a warning and does not abort the run) so both are fully loadable and the better one can be picked.
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -1013,6 +1013,28 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
preprocessor.push_to_hub(active_cfg.repo_id)
|
||||
postprocessor.push_to_hub(active_cfg.repo_id)
|
||||
|
||||
# When EMA is on we *eval* the EMA weights but the push above
|
||||
# ships the live weights — they're different models. Push the EMA
|
||||
# weights too, to a sibling ``<repo_id>-ema`` repo, so both are
|
||||
# fully loadable and you can benchmark/deploy whichever is better.
|
||||
# Non-fatal: the live model is already up if this fails.
|
||||
if ema is not None and not (
|
||||
not cfg.is_reward_model_training and cfg.policy.use_peft
|
||||
):
|
||||
ema_model = ema.ema_model
|
||||
ema_repo_id = f"{active_cfg.repo_id}-ema"
|
||||
orig_repo_id = ema_model.config.repo_id
|
||||
try:
|
||||
ema_model.config.repo_id = ema_repo_id
|
||||
ema_model.push_model_to_hub(cfg)
|
||||
preprocessor.push_to_hub(ema_repo_id)
|
||||
postprocessor.push_to_hub(ema_repo_id)
|
||||
logging.info("Pushed EMA weights to %s", ema_repo_id)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logging.warning("Failed to push EMA weights to %s: %s", ema_repo_id, exc)
|
||||
finally:
|
||||
ema_model.config.repo_id = orig_repo_id
|
||||
|
||||
# Properly clean up the distributed process group
|
||||
accelerator.wait_for_everyone()
|
||||
accelerator.end_training()
|
||||
|
||||
Reference in New Issue
Block a user