feat(train): tag each pushed checkpoint with its step

Address review feedback on #3856: pushing a checkpoint to the Hub now
also creates a tag named after the checkpoint step, so a checkpoint can
be recovered with --policy.pretrained_revision=<step> instead of having
to look up its commit sha.
This commit is contained in:
Nicolas Rabault
2026-06-25 16:48:15 +02:00
parent 3c8e54dcfa
commit 209685609d
3 changed files with 20 additions and 3 deletions
+1 -1
View File
@@ -651,7 +651,7 @@ hf upload ${HF_USER}/act_so101_test${CKPT} \
Use `lerobot-rollout` to deploy a trained policy on your robot. You can choose different strategies depending on your needs:
The examples below load the model from `--policy.path`. To pin a specific pushed version — useful once `--save_checkpoint_to_hub=true` has committed several checkpoints — add `--policy.pretrained_revision` with a commit hash, branch, or tag.
The examples below load the model from `--policy.path`. To pin a specific pushed version — useful once `--save_checkpoint_to_hub=true` has committed several checkpoints — add `--policy.pretrained_revision` with a commit hash, branch, or tag. Each pushed checkpoint is tagged with the name of its checkpoint directory, i.e. the step exactly as it is written there, including leading zeros (e.g. `--policy.pretrained_revision=010000`), so you can recover a checkpoint by step without looking up its commit SHA.
<hfoptions id="eval">
<hfoption id="Base mode (no recording)">
+11 -2
View File
@@ -296,14 +296,23 @@ def push_checkpoint_to_hub(
Called once per save step when save_checkpoint_to_hub is enabled, so a
timed-out or crashed run still leaves recoverable checkpoints on the Hub.
The model repo is created idempotently.
The model repo is created idempotently, and the commit is tagged with the
checkpoint step so a checkpoint can be recovered with
--policy.pretrained_revision=<step> instead of a commit sha.
"""
api = HfApi()
api.create_repo(repo_id=repo_id, repo_type="model", private=private, exist_ok=True)
api.upload_folder(
commit = api.upload_folder(
folder_path=str(checkpoint_dir),
repo_id=repo_id,
repo_type="model",
path_in_repo=f"checkpoints/{checkpoint_dir.name}",
commit_message=f"checkpoint {checkpoint_dir.name}",
)
api.create_tag(
repo_id=repo_id,
tag=checkpoint_dir.name,
revision=commit.oid,
repo_type="model",
exist_ok=True,
)
+8
View File
@@ -170,6 +170,14 @@ def test_push_checkpoint_to_hub_creates_repo_and_uploads(tmp_path, monkeypatch):
assert kwargs["path_in_repo"] == "checkpoints/010000"
assert kwargs["folder_path"] == str(ckpt)
assert kwargs["commit_message"] == "checkpoint 010000"
# A tag named after the checkpoint step is created so the checkpoint can be
# recovered with --policy.pretrained_revision instead of a commit sha.
api.create_tag.assert_called_once()
tag_kwargs = api.create_tag.call_args.kwargs
assert tag_kwargs["tag"] == "010000"
assert tag_kwargs["revision"] == api.upload_folder.return_value.oid
assert tag_kwargs["repo_type"] == "model"
assert tag_kwargs["exist_ok"] is True
def test_push_checkpoint_to_hub_defaults_to_hub_default_visibility(tmp_path, monkeypatch):