From 8d7099cd7d1c130692754766206cccc63b46fb36 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Thu, 16 Apr 2026 12:50:57 +0200 Subject: [PATCH] fix(profiling): publish preview runs via hf dataset prs --- scripts/ci/run_model_profiling.py | 81 ++++++++++++++++++++------- tests/scripts/test_model_profiling.py | 51 +++++++++++++++++ 2 files changed, 112 insertions(+), 20 deletions(-) diff --git a/scripts/ci/run_model_profiling.py b/scripts/ci/run_model_profiling.py index 7427fd459..2e810e216 100644 --- a/scripts/ci/run_model_profiling.py +++ b/scripts/ci/run_model_profiling.py @@ -18,6 +18,8 @@ from __future__ import annotations import argparse import json +import re +import shutil import subprocess import time from dataclasses import dataclass @@ -25,7 +27,7 @@ from datetime import UTC, datetime from pathlib import Path from typing import Any -from huggingface_hub import HfApi +from huggingface_hub import CommitOperationAdd, HfApi @dataclass(frozen=True) @@ -41,14 +43,32 @@ class UploadTarget: path_in_repo: str +@dataclass(frozen=True) +class UploadResult: + uploaded_paths: dict[str, str] + pr_url: str | None = None + + def utc_timestamp_slug(now: datetime | None = None) -> str: current = now or datetime.now(UTC) return current.strftime("%Y%m%dT%H%M%SZ") -def make_hub_file_url(repo_id: str, path_in_repo: str, repo_type: str = "dataset") -> str: +def make_hub_file_url( + repo_id: str, + path_in_repo: str, + repo_type: str = "dataset", + revision: str = "main", +) -> str: prefix = "datasets/" if repo_type == "dataset" else "" - return f"https://huggingface.co/{prefix}{repo_id}/resolve/main/{path_in_repo}" + return f"https://huggingface.co/{prefix}{repo_id}/resolve/{revision}/{path_in_repo}" + + +def parse_discussion_num(pr_url: str | None) -> int | None: + if not pr_url: + return None + match = re.search(r"/discussions/(\d+)$", pr_url) + return int(match.group(1)) if match else None def upload_targets( @@ -57,21 +77,33 @@ def upload_targets( *, repo_type: str = "dataset", token: str | None = None, - private: bool | None = None, commit_message: str | None = None, -) -> dict[str, str]: + create_pr: bool = False, +) -> UploadResult: api = HfApi(token=token) - uploaded: dict[str, str] = {} - for target in targets: - api.upload_file( - path_or_fileobj=str(target.local_path), - path_in_repo=target.path_in_repo, - repo_id=repo_id, - repo_type=repo_type, - commit_message=commit_message or f"Upload {target.path_in_repo}", + operations = [ + CommitOperationAdd(path_in_repo=target.path_in_repo, path_or_fileobj=str(target.local_path)) + for target in targets + ] + commit = api.create_commit( + repo_id=repo_id, + repo_type=repo_type, + operations=operations, + commit_message=commit_message or f"Upload {len(targets)} profiling artifacts", + revision="main", + create_pr=create_pr, + ) + revision = "main" + pr_num = parse_discussion_num(commit.pr_url) + if create_pr and pr_num is not None: + revision = f"refs/pr/{pr_num}" + uploaded = { + target.path_in_repo: make_hub_file_url( + repo_id, target.path_in_repo, repo_type=repo_type, revision=revision ) - uploaded[target.path_in_repo] = make_hub_file_url(repo_id, target.path_in_repo, repo_type=repo_type) - return uploaded + for target in targets + } + return UploadResult(uploaded_paths=uploaded, pr_url=commit.pr_url) def normalize_repo_id(repo: str, hub_org: str) -> str: @@ -205,13 +237,14 @@ def upload_profile_run( row_path: Path, row_path_in_repo: str, artifact_targets: list[UploadTarget], -) -> dict[str, str]: + create_pr: bool = False, +) -> UploadResult: return upload_targets( repo_id=repo_id, targets=[*artifact_targets, UploadTarget(local_path=row_path, path_in_repo=row_path_in_repo)], repo_type="dataset", - private=False, commit_message=f"Add model profiling row {row_path_in_repo}", + create_pr=create_pr, ) @@ -221,7 +254,12 @@ def main() -> int: selected = get_selected_names(args.policies, specs) args.output_dir.mkdir(parents=True, exist_ok=True) repo_id = normalize_repo_id(args.results_repo, args.hub_org) - git_commit = args.git_commit or subprocess.check_output(["git", "rev-parse", "HEAD"], text=True).strip() + git_executable = shutil.which("git") + if not git_executable: + raise RuntimeError("git executable not found in PATH") + git_commit = ( + args.git_commit or subprocess.check_output([git_executable, "rev-parse", "HEAD"], text=True).strip() + ) pr_number = int(args.pr_number) if str(args.pr_number).strip() else None for policy_name in selected: @@ -277,13 +315,16 @@ def main() -> int: row_path.write_text(json.dumps(row, indent=2, sort_keys=True)) if args.publish: - uploaded_paths = upload_profile_run( + upload_result = upload_profile_run( repo_id=repo_id, row_path=row_path, row_path_in_repo=row_path_in_repo, artifact_targets=artifact_targets, + create_pr=pr_number is not None, ) - row["uploaded_paths"] = uploaded_paths + row["uploaded_paths"] = upload_result.uploaded_paths + row["publish_pr_url"] = upload_result.pr_url + row["publish_pr_number"] = parse_discussion_num(upload_result.pr_url) row_path.write_text(json.dumps(row, indent=2, sort_keys=True)) print(json.dumps(row, indent=2, sort_keys=True)) diff --git a/tests/scripts/test_model_profiling.py b/tests/scripts/test_model_profiling.py index 7250b6e2c..f15579873 100644 --- a/tests/scripts/test_model_profiling.py +++ b/tests/scripts/test_model_profiling.py @@ -110,6 +110,43 @@ def test_build_artifact_index_collects_cprofile_tables_and_traces(tmp_path): assert len(targets) == 7 +def test_upload_targets_batches_preview_publish_into_single_hf_pr(monkeypatch, tmp_path): + module = _import_model_profiling_script() + local_path = tmp_path / "profiling_row.json" + local_path.write_text("{}") + captured: dict[str, object] = {} + + class _FakeCommit: + pr_url = "https://huggingface.co/datasets/lerobot/model-profiling-history/discussions/42" + + class _FakeApi: + def __init__(self, token=None): + captured["token"] = token + + def create_commit(self, **kwargs): + captured.update(kwargs) + return _FakeCommit() + + monkeypatch.setattr(module, "HfApi", _FakeApi) + + result = module.upload_targets( + repo_id="lerobot/model-profiling-history", + targets=[module.UploadTarget(local_path=local_path, path_in_repo="rows/act/run.json")], + create_pr=True, + token="hf_test_token", + ) + + assert captured["repo_id"] == "lerobot/model-profiling-history" + assert captured["repo_type"] == "dataset" + assert captured["revision"] == "main" + assert captured["create_pr"] is True + operations = captured["operations"] + assert len(operations) == 1 + assert operations[0].path_in_repo == "rows/act/run.json" + assert result.pr_url == _FakeCommit.pr_url + assert result.uploaded_paths["rows/act/run.json"].endswith("/resolve/refs/pr/42/rows/act/run.json") + + def test_model_profiling_main_smoke_writes_row(monkeypatch, tmp_path): module = _import_model_profiling_script() @@ -192,6 +229,20 @@ def test_model_profiling_main_smoke_writes_row(monkeypatch, tmp_path): assert "policy_setup" in row["artifact_paths"]["cprofile_summaries"] +def test_parse_discussion_num_handles_hf_discussion_urls(): + module = _import_model_profiling_script() + + assert ( + module.parse_discussion_num( + "https://huggingface.co/datasets/lerobot/model-profiling-history/discussions/42" + ) + == 42 + ) + assert ( + module.parse_discussion_num("https://huggingface.co/datasets/lerobot/model-profiling-history") is None + ) + + def test_deterministic_forward_artifacts_preserve_policy_mode(tmp_path): from lerobot.utils.profiling_utils import write_deterministic_forward_artifacts