fix(profiling): publish preview runs via hf dataset prs

This commit is contained in:
Pepijn
2026-04-16 12:50:57 +02:00
parent 516f39685a
commit 8d7099cd7d
2 changed files with 112 additions and 20 deletions
+59 -18
View File
@@ -18,6 +18,8 @@ from __future__ import annotations
import argparse
import json
import re
import shutil
import subprocess
import time
from dataclasses import dataclass
@@ -25,7 +27,7 @@ from datetime import UTC, datetime
from pathlib import Path
from typing import Any
from huggingface_hub import HfApi
from huggingface_hub import CommitOperationAdd, HfApi
@dataclass(frozen=True)
@@ -41,14 +43,32 @@ class UploadTarget:
path_in_repo: str
@dataclass(frozen=True)
class UploadResult:
uploaded_paths: dict[str, str]
pr_url: str | None = None
def utc_timestamp_slug(now: datetime | None = None) -> str:
current = now or datetime.now(UTC)
return current.strftime("%Y%m%dT%H%M%SZ")
def make_hub_file_url(
    repo_id: str,
    path_in_repo: str,
    repo_type: str = "dataset",
    revision: str = "main",
) -> str:
    """Build the huggingface.co ``resolve`` URL for a file in a Hub repo.

    Args:
        repo_id: Repository identifier, e.g. ``"org/name"``.
        path_in_repo: File path relative to the repository root.
        repo_type: ``"dataset"`` adds the ``datasets/`` URL prefix; any other
            value produces a URL without a prefix.
        revision: Revision segment placed after ``resolve/`` (for example
            ``"main"`` or ``"refs/pr/42"``). It is inserted verbatim, without
            URL-encoding.

    Returns:
        The file URL as a string.
    """
    prefix = "datasets/" if repo_type == "dataset" else ""
    return f"https://huggingface.co/{prefix}{repo_id}/resolve/{revision}/{path_in_repo}"
def parse_discussion_num(pr_url: str | None) -> int | None:
if not pr_url:
return None
match = re.search(r"/discussions/(\d+)$", pr_url)
return int(match.group(1)) if match else None
def upload_targets(
@@ -57,21 +77,33 @@ def upload_targets(
*,
repo_type: str = "dataset",
token: str | None = None,
private: bool | None = None,
commit_message: str | None = None,
) -> dict[str, str]:
create_pr: bool = False,
) -> UploadResult:
api = HfApi(token=token)
uploaded: dict[str, str] = {}
for target in targets:
api.upload_file(
path_or_fileobj=str(target.local_path),
path_in_repo=target.path_in_repo,
operations = [
CommitOperationAdd(path_in_repo=target.path_in_repo, path_or_fileobj=str(target.local_path))
for target in targets
]
commit = api.create_commit(
repo_id=repo_id,
repo_type=repo_type,
commit_message=commit_message or f"Upload {target.path_in_repo}",
operations=operations,
commit_message=commit_message or f"Upload {len(targets)} profiling artifacts",
revision="main",
create_pr=create_pr,
)
uploaded[target.path_in_repo] = make_hub_file_url(repo_id, target.path_in_repo, repo_type=repo_type)
return uploaded
revision = "main"
pr_num = parse_discussion_num(commit.pr_url)
if create_pr and pr_num is not None:
revision = f"refs/pr/{pr_num}"
uploaded = {
target.path_in_repo: make_hub_file_url(
repo_id, target.path_in_repo, repo_type=repo_type, revision=revision
)
for target in targets
}
return UploadResult(uploaded_paths=uploaded, pr_url=commit.pr_url)
def normalize_repo_id(repo: str, hub_org: str) -> str:
@@ -205,13 +237,14 @@ def upload_profile_run(
row_path: Path,
row_path_in_repo: str,
artifact_targets: list[UploadTarget],
) -> dict[str, str]:
create_pr: bool = False,
) -> UploadResult:
return upload_targets(
repo_id=repo_id,
targets=[*artifact_targets, UploadTarget(local_path=row_path, path_in_repo=row_path_in_repo)],
repo_type="dataset",
private=False,
commit_message=f"Add model profiling row {row_path_in_repo}",
create_pr=create_pr,
)
@@ -221,7 +254,12 @@ def main() -> int:
selected = get_selected_names(args.policies, specs)
args.output_dir.mkdir(parents=True, exist_ok=True)
repo_id = normalize_repo_id(args.results_repo, args.hub_org)
git_commit = args.git_commit or subprocess.check_output(["git", "rev-parse", "HEAD"], text=True).strip()
git_executable = shutil.which("git")
if not git_executable:
raise RuntimeError("git executable not found in PATH")
git_commit = (
args.git_commit or subprocess.check_output([git_executable, "rev-parse", "HEAD"], text=True).strip()
)
pr_number = int(args.pr_number) if str(args.pr_number).strip() else None
for policy_name in selected:
@@ -277,13 +315,16 @@ def main() -> int:
row_path.write_text(json.dumps(row, indent=2, sort_keys=True))
if args.publish:
uploaded_paths = upload_profile_run(
upload_result = upload_profile_run(
repo_id=repo_id,
row_path=row_path,
row_path_in_repo=row_path_in_repo,
artifact_targets=artifact_targets,
create_pr=pr_number is not None,
)
row["uploaded_paths"] = uploaded_paths
row["uploaded_paths"] = upload_result.uploaded_paths
row["publish_pr_url"] = upload_result.pr_url
row["publish_pr_number"] = parse_discussion_num(upload_result.pr_url)
row_path.write_text(json.dumps(row, indent=2, sort_keys=True))
print(json.dumps(row, indent=2, sort_keys=True))
+51
View File
@@ -110,6 +110,43 @@ def test_build_artifact_index_collects_cprofile_tables_and_traces(tmp_path):
assert len(targets) == 7
def test_upload_targets_batches_preview_publish_into_single_hf_pr(monkeypatch, tmp_path):
    """upload_targets issues one create_commit call as a PR and returns PR-scoped URLs."""
    module = _import_model_profiling_script()
    local_path = tmp_path / "profiling_row.json"
    local_path.write_text("{}")
    captured: dict[str, object] = {}

    class _FakeCommit:
        pr_url = "https://huggingface.co/datasets/lerobot/model-profiling-history/discussions/42"

    class _FakeApi:
        # Stand-in for HfApi: records constructor and create_commit arguments
        # so the test never touches the network.
        def __init__(self, token=None):
            captured["token"] = token

        def create_commit(self, **kwargs):
            captured.update(kwargs)
            return _FakeCommit()

    monkeypatch.setattr(module, "HfApi", _FakeApi)

    result = module.upload_targets(
        repo_id="lerobot/model-profiling-history",
        targets=[module.UploadTarget(local_path=local_path, path_in_repo="rows/act/run.json")],
        create_pr=True,
        token="hf_test_token",
    )

    # The token must be forwarded to the API client, not dropped.
    assert captured["token"] == "hf_test_token"
    assert captured["repo_id"] == "lerobot/model-profiling-history"
    assert captured["repo_type"] == "dataset"
    assert captured["revision"] == "main"
    assert captured["create_pr"] is True
    operations = captured["operations"]
    assert len(operations) == 1
    assert operations[0].path_in_repo == "rows/act/run.json"
    assert result.pr_url == _FakeCommit.pr_url
    # URLs must point at the PR ref, since the files are not on main yet.
    assert result.uploaded_paths["rows/act/run.json"].endswith("/resolve/refs/pr/42/rows/act/run.json")
def test_model_profiling_main_smoke_writes_row(monkeypatch, tmp_path):
module = _import_model_profiling_script()
@@ -192,6 +229,20 @@ def test_model_profiling_main_smoke_writes_row(monkeypatch, tmp_path):
assert "policy_setup" in row["artifact_paths"]["cprofile_summaries"]
def test_parse_discussion_num_handles_hf_discussion_urls():
    """parse_discussion_num returns the PR number, or None when there is none."""
    module = _import_model_profiling_script()
    assert (
        module.parse_discussion_num(
            "https://huggingface.co/datasets/lerobot/model-profiling-history/discussions/42"
        )
        == 42
    )
    assert (
        module.parse_discussion_num("https://huggingface.co/datasets/lerobot/model-profiling-history") is None
    )
    # A commit made without a PR reports no URL; both empty forms must map to None.
    assert module.parse_discussion_num(None) is None
    assert module.parse_discussion_num("") is None
def test_deterministic_forward_artifacts_preserve_policy_mode(tmp_path):
from lerobot.utils.profiling_utils import write_deterministic_forward_artifacts