mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 17:20:05 +00:00
fix(profiling): publish preview runs via hf dataset prs
This commit is contained in:
@@ -18,6 +18,8 @@ from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
@@ -25,7 +27,7 @@ from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub import CommitOperationAdd, HfApi
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -41,14 +43,32 @@ class UploadTarget:
|
||||
path_in_repo: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class UploadResult:
|
||||
uploaded_paths: dict[str, str]
|
||||
pr_url: str | None = None
|
||||
|
||||
|
||||
def utc_timestamp_slug(now: datetime | None = None) -> str:
|
||||
current = now or datetime.now(UTC)
|
||||
return current.strftime("%Y%m%dT%H%M%SZ")
|
||||
|
||||
|
||||
def make_hub_file_url(
    repo_id: str,
    path_in_repo: str,
    repo_type: str = "dataset",
    revision: str = "main",
) -> str:
    """Build the ``resolve`` URL of a file hosted on the Hugging Face Hub.

    Args:
        repo_id: Repository identifier, e.g. ``"org/name"``.
        path_in_repo: Path of the file inside the repository; inserted as-is.
        repo_type: ``"dataset"`` and ``"space"`` get their ``datasets/`` /
            ``spaces/`` URL prefix; any other value gets no prefix.
        revision: Branch, tag, commit or ref (e.g. ``"refs/pr/7"``) to
            resolve the file against.

    Returns:
        The ``https://huggingface.co/.../resolve/<revision>/<path>`` URL.
    """
    from urllib.parse import quote

    url_prefixes = {"dataset": "datasets/", "space": "spaces/"}
    prefix = url_prefixes.get(repo_type, "")
    # A revision such as "refs/pr/7" contains slashes; percent-encode it so it
    # stays a single path segment, matching how huggingface_hub builds URLs.
    encoded_revision = quote(revision, safe="")
    return f"https://huggingface.co/{prefix}{repo_id}/resolve/{encoded_revision}/{path_in_repo}"
|
||||
|
||||
|
||||
def parse_discussion_num(pr_url: str | None) -> int | None:
|
||||
if not pr_url:
|
||||
return None
|
||||
match = re.search(r"/discussions/(\d+)$", pr_url)
|
||||
return int(match.group(1)) if match else None
|
||||
|
||||
|
||||
def upload_targets(
|
||||
@@ -57,21 +77,33 @@ def upload_targets(
|
||||
*,
|
||||
repo_type: str = "dataset",
|
||||
token: str | None = None,
|
||||
private: bool | None = None,
|
||||
commit_message: str | None = None,
|
||||
) -> dict[str, str]:
|
||||
create_pr: bool = False,
|
||||
) -> UploadResult:
|
||||
api = HfApi(token=token)
|
||||
uploaded: dict[str, str] = {}
|
||||
for target in targets:
|
||||
api.upload_file(
|
||||
path_or_fileobj=str(target.local_path),
|
||||
path_in_repo=target.path_in_repo,
|
||||
repo_id=repo_id,
|
||||
repo_type=repo_type,
|
||||
commit_message=commit_message or f"Upload {target.path_in_repo}",
|
||||
operations = [
|
||||
CommitOperationAdd(path_in_repo=target.path_in_repo, path_or_fileobj=str(target.local_path))
|
||||
for target in targets
|
||||
]
|
||||
commit = api.create_commit(
|
||||
repo_id=repo_id,
|
||||
repo_type=repo_type,
|
||||
operations=operations,
|
||||
commit_message=commit_message or f"Upload {len(targets)} profiling artifacts",
|
||||
revision="main",
|
||||
create_pr=create_pr,
|
||||
)
|
||||
revision = "main"
|
||||
pr_num = parse_discussion_num(commit.pr_url)
|
||||
if create_pr and pr_num is not None:
|
||||
revision = f"refs/pr/{pr_num}"
|
||||
uploaded = {
|
||||
target.path_in_repo: make_hub_file_url(
|
||||
repo_id, target.path_in_repo, repo_type=repo_type, revision=revision
|
||||
)
|
||||
uploaded[target.path_in_repo] = make_hub_file_url(repo_id, target.path_in_repo, repo_type=repo_type)
|
||||
return uploaded
|
||||
for target in targets
|
||||
}
|
||||
return UploadResult(uploaded_paths=uploaded, pr_url=commit.pr_url)
|
||||
|
||||
|
||||
def normalize_repo_id(repo: str, hub_org: str) -> str:
|
||||
@@ -205,13 +237,14 @@ def upload_profile_run(
|
||||
row_path: Path,
|
||||
row_path_in_repo: str,
|
||||
artifact_targets: list[UploadTarget],
|
||||
) -> dict[str, str]:
|
||||
create_pr: bool = False,
|
||||
) -> UploadResult:
|
||||
return upload_targets(
|
||||
repo_id=repo_id,
|
||||
targets=[*artifact_targets, UploadTarget(local_path=row_path, path_in_repo=row_path_in_repo)],
|
||||
repo_type="dataset",
|
||||
private=False,
|
||||
commit_message=f"Add model profiling row {row_path_in_repo}",
|
||||
create_pr=create_pr,
|
||||
)
|
||||
|
||||
|
||||
@@ -221,7 +254,12 @@ def main() -> int:
|
||||
selected = get_selected_names(args.policies, specs)
|
||||
args.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
repo_id = normalize_repo_id(args.results_repo, args.hub_org)
|
||||
git_commit = args.git_commit or subprocess.check_output(["git", "rev-parse", "HEAD"], text=True).strip()
|
||||
git_executable = shutil.which("git")
|
||||
if not git_executable:
|
||||
raise RuntimeError("git executable not found in PATH")
|
||||
git_commit = (
|
||||
args.git_commit or subprocess.check_output([git_executable, "rev-parse", "HEAD"], text=True).strip()
|
||||
)
|
||||
pr_number = int(args.pr_number) if str(args.pr_number).strip() else None
|
||||
|
||||
for policy_name in selected:
|
||||
@@ -277,13 +315,16 @@ def main() -> int:
|
||||
row_path.write_text(json.dumps(row, indent=2, sort_keys=True))
|
||||
|
||||
if args.publish:
|
||||
uploaded_paths = upload_profile_run(
|
||||
upload_result = upload_profile_run(
|
||||
repo_id=repo_id,
|
||||
row_path=row_path,
|
||||
row_path_in_repo=row_path_in_repo,
|
||||
artifact_targets=artifact_targets,
|
||||
create_pr=pr_number is not None,
|
||||
)
|
||||
row["uploaded_paths"] = uploaded_paths
|
||||
row["uploaded_paths"] = upload_result.uploaded_paths
|
||||
row["publish_pr_url"] = upload_result.pr_url
|
||||
row["publish_pr_number"] = parse_discussion_num(upload_result.pr_url)
|
||||
row_path.write_text(json.dumps(row, indent=2, sort_keys=True))
|
||||
|
||||
print(json.dumps(row, indent=2, sort_keys=True))
|
||||
|
||||
Reference in New Issue
Block a user