fix(profiling): publish preview runs via hf dataset prs

This commit is contained in:
Pepijn
2026-04-16 12:50:57 +02:00
parent 516f39685a
commit 8d7099cd7d
2 changed files with 112 additions and 20 deletions
+61 -20
View File
@@ -18,6 +18,8 @@ from __future__ import annotations
import argparse import argparse
import json import json
import re
import shutil
import subprocess import subprocess
import time import time
from dataclasses import dataclass from dataclasses import dataclass
@@ -25,7 +27,7 @@ from datetime import UTC, datetime
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from huggingface_hub import HfApi from huggingface_hub import CommitOperationAdd, HfApi
@dataclass(frozen=True) @dataclass(frozen=True)
@@ -41,14 +43,32 @@ class UploadTarget:
path_in_repo: str path_in_repo: str
@dataclass(frozen=True)
class UploadResult:
    """Immutable summary of one publish operation.

    Bundles the per-file URLs together with the pull-request URL (if any)
    reported for the commit that carried the files.
    """

    # Keyed by path inside the repo; each value is the URL built for that file.
    uploaded_paths: dict[str, str]
    # Pull-request URL reported for the commit; stays None when none is reported.
    pr_url: str | None = None
def utc_timestamp_slug(now: datetime | None = None) -> str: def utc_timestamp_slug(now: datetime | None = None) -> str:
current = now or datetime.now(UTC) current = now or datetime.now(UTC)
return current.strftime("%Y%m%dT%H%M%SZ") return current.strftime("%Y%m%dT%H%M%SZ")
def make_hub_file_url(
    repo_id: str,
    path_in_repo: str,
    repo_type: str = "dataset",
    revision: str = "main",
) -> str:
    """Build the ``resolve`` URL for a file stored in a Hugging Face repo.

    Args:
        repo_id: Repository identifier, e.g. ``"org/name"``.
        path_in_repo: File path relative to the repo root. It is
            percent-encoded (``/`` is kept) so names containing spaces,
            ``#`` or ``?`` still yield a valid URL.
        repo_type: ``"dataset"`` adds the ``datasets/`` URL prefix; any other
            value adds no prefix.
        revision: Revision segment placed after ``resolve/``; inserted as-is.

    Returns:
        The ``https://huggingface.co/.../resolve/<revision>/<path>`` URL.
    """
    # Function-scope import so this helper does not rely on a module-level import.
    from urllib.parse import quote

    prefix = "datasets/" if repo_type == "dataset" else ""
    # quote() keeps "/" by default, so directory separators survive encoding.
    encoded_path = quote(path_in_repo)
    return f"https://huggingface.co/{prefix}{repo_id}/resolve/{revision}/{encoded_path}"
def parse_discussion_num(pr_url: str | None) -> int | None:
    """Extract the discussion / pull-request number from a Hub URL.

    Args:
        pr_url: URL containing a ``/discussions/<number>`` segment, or
            ``None`` / an empty string when there is no URL.

    Returns:
        The number as an ``int``, or ``None`` when ``pr_url`` is falsy or has
        no ``/discussions/<number>`` segment.
    """
    if not pr_url:
        return None
    # The digits must be followed by a segment boundary ("/", "?", "#") or the
    # end of the string, so a trailing slash, query string, fragment or
    # sub-path is tolerated while "/discussions/42abc" is rejected.
    match = re.search(r"/discussions/(\d+)(?:[/?#]|$)", pr_url)
    return int(match.group(1)) if match else None
def upload_targets( def upload_targets(
@@ -57,21 +77,33 @@ def upload_targets(
*, *,
repo_type: str = "dataset", repo_type: str = "dataset",
token: str | None = None, token: str | None = None,
private: bool | None = None,
commit_message: str | None = None, commit_message: str | None = None,
) -> dict[str, str]: create_pr: bool = False,
) -> UploadResult:
api = HfApi(token=token) api = HfApi(token=token)
uploaded: dict[str, str] = {} operations = [
for target in targets: CommitOperationAdd(path_in_repo=target.path_in_repo, path_or_fileobj=str(target.local_path))
api.upload_file( for target in targets
path_or_fileobj=str(target.local_path), ]
path_in_repo=target.path_in_repo, commit = api.create_commit(
repo_id=repo_id, repo_id=repo_id,
repo_type=repo_type, repo_type=repo_type,
commit_message=commit_message or f"Upload {target.path_in_repo}", operations=operations,
commit_message=commit_message or f"Upload {len(targets)} profiling artifacts",
revision="main",
create_pr=create_pr,
)
revision = "main"
pr_num = parse_discussion_num(commit.pr_url)
if create_pr and pr_num is not None:
revision = f"refs/pr/{pr_num}"
uploaded = {
target.path_in_repo: make_hub_file_url(
repo_id, target.path_in_repo, repo_type=repo_type, revision=revision
) )
uploaded[target.path_in_repo] = make_hub_file_url(repo_id, target.path_in_repo, repo_type=repo_type) for target in targets
return uploaded }
return UploadResult(uploaded_paths=uploaded, pr_url=commit.pr_url)
def normalize_repo_id(repo: str, hub_org: str) -> str: def normalize_repo_id(repo: str, hub_org: str) -> str:
@@ -205,13 +237,14 @@ def upload_profile_run(
row_path: Path, row_path: Path,
row_path_in_repo: str, row_path_in_repo: str,
artifact_targets: list[UploadTarget], artifact_targets: list[UploadTarget],
) -> dict[str, str]: create_pr: bool = False,
) -> UploadResult:
return upload_targets( return upload_targets(
repo_id=repo_id, repo_id=repo_id,
targets=[*artifact_targets, UploadTarget(local_path=row_path, path_in_repo=row_path_in_repo)], targets=[*artifact_targets, UploadTarget(local_path=row_path, path_in_repo=row_path_in_repo)],
repo_type="dataset", repo_type="dataset",
private=False,
commit_message=f"Add model profiling row {row_path_in_repo}", commit_message=f"Add model profiling row {row_path_in_repo}",
create_pr=create_pr,
) )
@@ -221,7 +254,12 @@ def main() -> int:
selected = get_selected_names(args.policies, specs) selected = get_selected_names(args.policies, specs)
args.output_dir.mkdir(parents=True, exist_ok=True) args.output_dir.mkdir(parents=True, exist_ok=True)
repo_id = normalize_repo_id(args.results_repo, args.hub_org) repo_id = normalize_repo_id(args.results_repo, args.hub_org)
git_commit = args.git_commit or subprocess.check_output(["git", "rev-parse", "HEAD"], text=True).strip() git_executable = shutil.which("git")
if not git_executable:
raise RuntimeError("git executable not found in PATH")
git_commit = (
args.git_commit or subprocess.check_output([git_executable, "rev-parse", "HEAD"], text=True).strip()
)
pr_number = int(args.pr_number) if str(args.pr_number).strip() else None pr_number = int(args.pr_number) if str(args.pr_number).strip() else None
for policy_name in selected: for policy_name in selected:
@@ -277,13 +315,16 @@ def main() -> int:
row_path.write_text(json.dumps(row, indent=2, sort_keys=True)) row_path.write_text(json.dumps(row, indent=2, sort_keys=True))
if args.publish: if args.publish:
uploaded_paths = upload_profile_run( upload_result = upload_profile_run(
repo_id=repo_id, repo_id=repo_id,
row_path=row_path, row_path=row_path,
row_path_in_repo=row_path_in_repo, row_path_in_repo=row_path_in_repo,
artifact_targets=artifact_targets, artifact_targets=artifact_targets,
create_pr=pr_number is not None,
) )
row["uploaded_paths"] = uploaded_paths row["uploaded_paths"] = upload_result.uploaded_paths
row["publish_pr_url"] = upload_result.pr_url
row["publish_pr_number"] = parse_discussion_num(upload_result.pr_url)
row_path.write_text(json.dumps(row, indent=2, sort_keys=True)) row_path.write_text(json.dumps(row, indent=2, sort_keys=True))
print(json.dumps(row, indent=2, sort_keys=True)) print(json.dumps(row, indent=2, sort_keys=True))
+51
View File
@@ -110,6 +110,43 @@ def test_build_artifact_index_collects_cprofile_tables_and_traces(tmp_path):
assert len(targets) == 7 assert len(targets) == 7
def test_upload_targets_batches_preview_publish_into_single_hf_pr(monkeypatch, tmp_path):
    """upload_targets issues one create_commit call and reports PR-scoped URLs.

    ``HfApi`` is replaced by a fake that records the constructor token and the
    keyword arguments passed to ``create_commit``, so no network is used.
    """
    module = _import_model_profiling_script()
    local_path = tmp_path / "profiling_row.json"
    local_path.write_text("{}")

    captured: dict[str, object] = {}

    class _FakeCommit:
        pr_url = "https://huggingface.co/datasets/lerobot/model-profiling-history/discussions/42"

    class _FakeApi:
        def __init__(self, token=None):
            captured["token"] = token

        def create_commit(self, **kwargs):
            captured.update(kwargs)
            return _FakeCommit()

    monkeypatch.setattr(module, "HfApi", _FakeApi)

    result = module.upload_targets(
        repo_id="lerobot/model-profiling-history",
        targets=[module.UploadTarget(local_path=local_path, path_in_repo="rows/act/run.json")],
        create_pr=True,
        token="hf_test_token",
    )

    # The token handed to upload_targets must reach the API client; the fake
    # records it, so verify it rather than leaving the capture unused.
    assert captured["token"] == "hf_test_token"
    assert captured["repo_id"] == "lerobot/model-profiling-history"
    assert captured["repo_type"] == "dataset"
    assert captured["revision"] == "main"
    assert captured["create_pr"] is True

    operations = captured["operations"]
    assert len(operations) == 1
    assert operations[0].path_in_repo == "rows/act/run.json"

    assert result.pr_url == _FakeCommit.pr_url
    assert result.uploaded_paths["rows/act/run.json"].endswith("/resolve/refs/pr/42/rows/act/run.json")
def test_model_profiling_main_smoke_writes_row(monkeypatch, tmp_path): def test_model_profiling_main_smoke_writes_row(monkeypatch, tmp_path):
module = _import_model_profiling_script() module = _import_model_profiling_script()
@@ -192,6 +229,20 @@ def test_model_profiling_main_smoke_writes_row(monkeypatch, tmp_path):
assert "policy_setup" in row["artifact_paths"]["cprofile_summaries"] assert "policy_setup" in row["artifact_paths"]["cprofile_summaries"]
def test_parse_discussion_num_handles_hf_discussion_urls():
    """A discussion URL yields its number; a plain repo URL yields None."""
    parse = _import_model_profiling_script().parse_discussion_num

    discussion_url = "https://huggingface.co/datasets/lerobot/model-profiling-history/discussions/42"
    repo_url = "https://huggingface.co/datasets/lerobot/model-profiling-history"

    assert parse(discussion_url) == 42
    assert parse(repo_url) is None
def test_deterministic_forward_artifacts_preserve_policy_mode(tmp_path): def test_deterministic_forward_artifacts_preserve_policy_mode(tmp_path):
from lerobot.utils.profiling_utils import write_deterministic_forward_artifacts from lerobot.utils.profiling_utils import write_deterministic_forward_artifacts