diff --git a/docs/source/il_robots.mdx b/docs/source/il_robots.mdx index 0b7009357..71b00aae6 100644 --- a/docs/source/il_robots.mdx +++ b/docs/source/il_robots.mdx @@ -651,7 +651,7 @@ hf upload ${HF_USER}/act_so101_test${CKPT} \ Use `lerobot-rollout` to deploy a trained policy on your robot. You can choose different strategies depending on your needs: -The examples below load the model from `--policy.path`. To pin a specific pushed version — useful once `--save_checkpoint_to_hub=true` has committed several checkpoints — add `--policy.pretrained_revision` with a commit hash, branch, or tag. +The examples below load the model from `--policy.path`. To pin a specific pushed version — useful once `--save_checkpoint_to_hub=true` has committed several checkpoints — add `--policy.pretrained_revision` with a commit hash, branch, or tag. Each pushed checkpoint is tagged with its step (e.g. `--policy.pretrained_revision=010000`), so you can recover a checkpoint by step without looking up its commit sha. diff --git a/src/lerobot/common/train_utils.py b/src/lerobot/common/train_utils.py index 5b5646ed3..9d6afcad2 100644 --- a/src/lerobot/common/train_utils.py +++ b/src/lerobot/common/train_utils.py @@ -296,14 +296,23 @@ def push_checkpoint_to_hub( Called once per save step when save_checkpoint_to_hub is enabled, so a timed-out or crashed run still leaves recoverable checkpoints on the Hub. - The model repo is created idempotently. + The model repo is created idempotently, and the commit is tagged with the + checkpoint step so a checkpoint can be recovered with + --policy.pretrained_revision=<step> instead of a commit sha. 
""" api = HfApi() api.create_repo(repo_id=repo_id, repo_type="model", private=private, exist_ok=True) - api.upload_folder( + commit = api.upload_folder( folder_path=str(checkpoint_dir), repo_id=repo_id, repo_type="model", path_in_repo=f"checkpoints/{checkpoint_dir.name}", commit_message=f"checkpoint {checkpoint_dir.name}", ) + api.create_tag( + repo_id=repo_id, + tag=checkpoint_dir.name, + revision=commit.oid, + repo_type="model", + exist_ok=True, + ) diff --git a/tests/utils/test_train_utils.py b/tests/utils/test_train_utils.py index 2d8fc9e6c..461c5f031 100644 --- a/tests/utils/test_train_utils.py +++ b/tests/utils/test_train_utils.py @@ -170,6 +170,14 @@ def test_push_checkpoint_to_hub_creates_repo_and_uploads(tmp_path, monkeypatch): assert kwargs["path_in_repo"] == "checkpoints/010000" assert kwargs["folder_path"] == str(ckpt) assert kwargs["commit_message"] == "checkpoint 010000" + # A tag named after the checkpoint step is created so the checkpoint can be + # recovered with --policy.pretrained_revision instead of a commit sha. + api.create_tag.assert_called_once() + tag_kwargs = api.create_tag.call_args.kwargs + assert tag_kwargs["tag"] == "010000" + assert tag_kwargs["revision"] == api.upload_folder.return_value.oid + assert tag_kwargs["repo_type"] == "model" + assert tag_kwargs["exist_ok"] is True def test_push_checkpoint_to_hub_defaults_to_hub_default_visibility(tmp_path, monkeypatch):