fix(profiling): publish preview runs via hf dataset prs

This commit is contained in:
Pepijn
2026-04-16 12:50:57 +02:00
parent 516f39685a
commit 8d7099cd7d
2 changed files with 112 additions and 20 deletions
+61 -20
View File
@@ -18,6 +18,8 @@ from __future__ import annotations
import argparse
import json
import re
import shutil
import subprocess
import time
from dataclasses import dataclass
@@ -25,7 +27,7 @@ from datetime import UTC, datetime
from pathlib import Path
from typing import Any
from huggingface_hub import HfApi
from huggingface_hub import CommitOperationAdd, HfApi
@dataclass(frozen=True)
@@ -41,14 +43,32 @@ class UploadTarget:
path_in_repo: str
@dataclass(frozen=True)
class UploadResult:
uploaded_paths: dict[str, str]
pr_url: str | None = None
def utc_timestamp_slug(now: datetime | None = None) -> str:
current = now or datetime.now(UTC)
return current.strftime("%Y%m%dT%H%M%SZ")
def make_hub_file_url(repo_id: str, path_in_repo: str, repo_type: str = "dataset") -> str:
def make_hub_file_url(
repo_id: str,
path_in_repo: str,
repo_type: str = "dataset",
revision: str = "main",
) -> str:
prefix = "datasets/" if repo_type == "dataset" else ""
return f"https://huggingface.co/{prefix}{repo_id}/resolve/main/{path_in_repo}"
return f"https://huggingface.co/{prefix}{repo_id}/resolve/{revision}/{path_in_repo}"
def parse_discussion_num(pr_url: str | None) -> int | None:
if not pr_url:
return None
match = re.search(r"/discussions/(\d+)$", pr_url)
return int(match.group(1)) if match else None
def upload_targets(
@@ -57,21 +77,33 @@ def upload_targets(
*,
repo_type: str = "dataset",
token: str | None = None,
private: bool | None = None,
commit_message: str | None = None,
) -> dict[str, str]:
create_pr: bool = False,
) -> UploadResult:
api = HfApi(token=token)
uploaded: dict[str, str] = {}
for target in targets:
api.upload_file(
path_or_fileobj=str(target.local_path),
path_in_repo=target.path_in_repo,
repo_id=repo_id,
repo_type=repo_type,
commit_message=commit_message or f"Upload {target.path_in_repo}",
operations = [
CommitOperationAdd(path_in_repo=target.path_in_repo, path_or_fileobj=str(target.local_path))
for target in targets
]
commit = api.create_commit(
repo_id=repo_id,
repo_type=repo_type,
operations=operations,
commit_message=commit_message or f"Upload {len(targets)} profiling artifacts",
revision="main",
create_pr=create_pr,
)
revision = "main"
pr_num = parse_discussion_num(commit.pr_url)
if create_pr and pr_num is not None:
revision = f"refs/pr/{pr_num}"
uploaded = {
target.path_in_repo: make_hub_file_url(
repo_id, target.path_in_repo, repo_type=repo_type, revision=revision
)
uploaded[target.path_in_repo] = make_hub_file_url(repo_id, target.path_in_repo, repo_type=repo_type)
return uploaded
for target in targets
}
return UploadResult(uploaded_paths=uploaded, pr_url=commit.pr_url)
def normalize_repo_id(repo: str, hub_org: str) -> str:
@@ -205,13 +237,14 @@ def upload_profile_run(
row_path: Path,
row_path_in_repo: str,
artifact_targets: list[UploadTarget],
) -> dict[str, str]:
create_pr: bool = False,
) -> UploadResult:
return upload_targets(
repo_id=repo_id,
targets=[*artifact_targets, UploadTarget(local_path=row_path, path_in_repo=row_path_in_repo)],
repo_type="dataset",
private=False,
commit_message=f"Add model profiling row {row_path_in_repo}",
create_pr=create_pr,
)
@@ -221,7 +254,12 @@ def main() -> int:
selected = get_selected_names(args.policies, specs)
args.output_dir.mkdir(parents=True, exist_ok=True)
repo_id = normalize_repo_id(args.results_repo, args.hub_org)
git_commit = args.git_commit or subprocess.check_output(["git", "rev-parse", "HEAD"], text=True).strip()
git_executable = shutil.which("git")
if not git_executable:
raise RuntimeError("git executable not found in PATH")
git_commit = (
args.git_commit or subprocess.check_output([git_executable, "rev-parse", "HEAD"], text=True).strip()
)
pr_number = int(args.pr_number) if str(args.pr_number).strip() else None
for policy_name in selected:
@@ -277,13 +315,16 @@ def main() -> int:
row_path.write_text(json.dumps(row, indent=2, sort_keys=True))
if args.publish:
uploaded_paths = upload_profile_run(
upload_result = upload_profile_run(
repo_id=repo_id,
row_path=row_path,
row_path_in_repo=row_path_in_repo,
artifact_targets=artifact_targets,
create_pr=pr_number is not None,
)
row["uploaded_paths"] = uploaded_paths
row["uploaded_paths"] = upload_result.uploaded_paths
row["publish_pr_url"] = upload_result.pr_url
row["publish_pr_number"] = parse_discussion_num(upload_result.pr_url)
row_path.write_text(json.dumps(row, indent=2, sort_keys=True))
print(json.dumps(row, indent=2, sort_keys=True))