mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-29 22:27:14 +00:00
feat(train): tag each pushed checkpoint with its step
Address review feedback on #3856: pushing a checkpoint to the Hub now also creates a tag named after the checkpoint step, so a checkpoint can be recovered with --policy.pretrained_revision=<step> instead of having to look up its commit sha.
This commit is contained in:
@@ -651,7 +651,7 @@ hf upload ${HF_USER}/act_so101_test${CKPT} \
|
||||
|
||||
Use `lerobot-rollout` to deploy a trained policy on your robot. You can choose different strategies depending on your needs:
|
||||
|
||||
The examples below load the model from `--policy.path`. To pin a specific pushed version — useful once `--save_checkpoint_to_hub=true` has committed several checkpoints — add `--policy.pretrained_revision` with a commit hash, branch, or tag.
|
||||
The examples below load the model from `--policy.path`. To pin a specific pushed version — useful once `--save_checkpoint_to_hub=true` has committed several checkpoints — add `--policy.pretrained_revision` with a commit hash, branch, or tag. Each pushed checkpoint is also tagged with its checkpoint name, i.e. the step exactly as it appears under `checkpoints/` (e.g. `010000`), so you can recover a checkpoint by step with `--policy.pretrained_revision=010000` instead of looking up its commit SHA.
|
||||
|
||||
<hfoptions id="eval">
|
||||
<hfoption id="Base mode (no recording)">
|
||||
|
||||
@@ -296,14 +296,23 @@ def push_checkpoint_to_hub(
|
||||
|
||||
Called once per save step when save_checkpoint_to_hub is enabled, so a
|
||||
timed-out or crashed run still leaves recoverable checkpoints on the Hub.
|
||||
The model repo is created idempotently.
|
||||
The model repo is created idempotently, and the commit is tagged with the
|
||||
checkpoint step so a checkpoint can be recovered with
|
||||
--policy.pretrained_revision=<step> instead of a commit sha.
|
||||
"""
|
||||
api = HfApi()
|
||||
api.create_repo(repo_id=repo_id, repo_type="model", private=private, exist_ok=True)
|
||||
api.upload_folder(
|
||||
commit = api.upload_folder(
|
||||
folder_path=str(checkpoint_dir),
|
||||
repo_id=repo_id,
|
||||
repo_type="model",
|
||||
path_in_repo=f"checkpoints/{checkpoint_dir.name}",
|
||||
commit_message=f"checkpoint {checkpoint_dir.name}",
|
||||
)
|
||||
api.create_tag(
|
||||
repo_id=repo_id,
|
||||
tag=checkpoint_dir.name,
|
||||
revision=commit.oid,
|
||||
repo_type="model",
|
||||
exist_ok=True,
|
||||
)
|
||||
|
||||
@@ -170,6 +170,14 @@ def test_push_checkpoint_to_hub_creates_repo_and_uploads(tmp_path, monkeypatch):
|
||||
assert kwargs["path_in_repo"] == "checkpoints/010000"
|
||||
assert kwargs["folder_path"] == str(ckpt)
|
||||
assert kwargs["commit_message"] == "checkpoint 010000"
|
||||
# A tag named after the checkpoint step is created so the checkpoint can be
|
||||
# recovered with --policy.pretrained_revision instead of a commit sha.
|
||||
api.create_tag.assert_called_once()
|
||||
tag_kwargs = api.create_tag.call_args.kwargs
|
||||
assert tag_kwargs["tag"] == "010000"
|
||||
assert tag_kwargs["revision"] == api.upload_folder.return_value.oid
|
||||
assert tag_kwargs["repo_type"] == "model"
|
||||
assert tag_kwargs["exist_ok"] is True
|
||||
|
||||
|
||||
def test_push_checkpoint_to_hub_defaults_to_hub_default_visibility(tmp_path, monkeypatch):
|
||||
|
||||
Reference in New Issue
Block a user