diff --git a/src/lerobot/scripts/lerobot_annotate.py b/src/lerobot/scripts/lerobot_annotate.py index 86d3ab3fa..7fee1f052 100644 --- a/src/lerobot/scripts/lerobot_annotate.py +++ b/src/lerobot/scripts/lerobot_annotate.py @@ -140,7 +140,7 @@ def _push_to_hub(root: Path, cfg: AnnotationPipelineConfig) -> None: exist_ok=True, ) print(f"[lerobot-annotate] uploading {root} -> {repo_id}...", flush=True) - api.upload_folder( + commit_info = api.upload_folder( folder_path=str(root), repo_id=repo_id, repo_type="dataset", @@ -169,13 +169,18 @@ def _push_to_hub(root: Path, cfg: AnnotationPipelineConfig) -> None: version_tag = ds_version except Exception as exc: # noqa: BLE001 print(f"[lerobot-annotate] could not read codebase_version from info.json ({exc}); falling back to {version_tag}", flush=True) + revision = getattr(commit_info, "oid", None) + tag_kwargs = { + "repo_id": repo_id, + "tag": version_tag, + "repo_type": "dataset", + "exist_ok": True, + } + if revision is not None: + tag_kwargs["revision"] = revision + try: - api.create_tag( - repo_id=repo_id, - tag=version_tag, - repo_type="dataset", - exist_ok=True, - ) + api.create_tag(**tag_kwargs) print(f"[lerobot-annotate] tagged {repo_id} as {version_tag}", flush=True) except Exception as exc: # noqa: BLE001 print( diff --git a/tests/scripts/test_lerobot_annotate.py b/tests/scripts/test_lerobot_annotate.py new file mode 100644 index 000000000..c98ee7cb3 --- /dev/null +++ b/tests/scripts/test_lerobot_annotate.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python + +import json +from types import SimpleNamespace + + +def test_push_to_hub_tags_uploaded_dataset_revision(tmp_path, monkeypatch): + from lerobot.scripts.lerobot_annotate import _push_to_hub + + root = tmp_path / "dataset" + (root / "meta").mkdir(parents=True) + (root / "meta" / "info.json").write_text(json.dumps({"codebase_version": "v3.0"})) + + calls = {} + + class FakeHfApi: + def create_repo(self, **kwargs): + calls["create_repo"] = kwargs + + def upload_folder(self, **kwargs): + calls["upload_folder"] = kwargs + return SimpleNamespace(oid="abc123") + + def create_tag(self, **kwargs): + calls["create_tag"] = kwargs + + monkeypatch.setattr("huggingface_hub.HfApi", FakeHfApi) + + cfg = SimpleNamespace( + repo_id="source/dataset", + dest_repo_id="annotated/dataset", + push_private=True, + push_commit_message=None, + ) + + _push_to_hub(root, cfg) + + assert calls["create_repo"] == { + "repo_id": "annotated/dataset", + "repo_type": "dataset", + "private": True, + "exist_ok": True, + } + assert calls["upload_folder"]["repo_id"] == "annotated/dataset" + assert calls["create_tag"] == { + "repo_id": "annotated/dataset", + "tag": "v3.0", + "repo_type": "dataset", + "exist_ok": True, + "revision": "abc123", + }