mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 19:49:49 +00:00
fix(profiling): publish preview runs via hf dataset prs
This commit is contained in:
@@ -18,6 +18,8 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
|
import re
|
||||||
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@@ -25,7 +27,7 @@ from datetime import UTC, datetime
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from huggingface_hub import HfApi
|
from huggingface_hub import CommitOperationAdd, HfApi
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
@@ -41,14 +43,32 @@ class UploadTarget:
|
|||||||
path_in_repo: str
|
path_in_repo: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class UploadResult:
|
||||||
|
uploaded_paths: dict[str, str]
|
||||||
|
pr_url: str | None = None
|
||||||
|
|
||||||
|
|
||||||
def utc_timestamp_slug(now: datetime | None = None) -> str:
|
def utc_timestamp_slug(now: datetime | None = None) -> str:
|
||||||
current = now or datetime.now(UTC)
|
current = now or datetime.now(UTC)
|
||||||
return current.strftime("%Y%m%dT%H%M%SZ")
|
return current.strftime("%Y%m%dT%H%M%SZ")
|
||||||
|
|
||||||
|
|
||||||
def make_hub_file_url(repo_id: str, path_in_repo: str, repo_type: str = "dataset") -> str:
|
def make_hub_file_url(
|
||||||
|
repo_id: str,
|
||||||
|
path_in_repo: str,
|
||||||
|
repo_type: str = "dataset",
|
||||||
|
revision: str = "main",
|
||||||
|
) -> str:
|
||||||
prefix = "datasets/" if repo_type == "dataset" else ""
|
prefix = "datasets/" if repo_type == "dataset" else ""
|
||||||
return f"https://huggingface.co/{prefix}{repo_id}/resolve/main/{path_in_repo}"
|
return f"https://huggingface.co/{prefix}{repo_id}/resolve/{revision}/{path_in_repo}"
|
||||||
|
|
||||||
|
|
||||||
|
def parse_discussion_num(pr_url: str | None) -> int | None:
|
||||||
|
if not pr_url:
|
||||||
|
return None
|
||||||
|
match = re.search(r"/discussions/(\d+)$", pr_url)
|
||||||
|
return int(match.group(1)) if match else None
|
||||||
|
|
||||||
|
|
||||||
def upload_targets(
|
def upload_targets(
|
||||||
@@ -57,21 +77,33 @@ def upload_targets(
|
|||||||
*,
|
*,
|
||||||
repo_type: str = "dataset",
|
repo_type: str = "dataset",
|
||||||
token: str | None = None,
|
token: str | None = None,
|
||||||
private: bool | None = None,
|
|
||||||
commit_message: str | None = None,
|
commit_message: str | None = None,
|
||||||
) -> dict[str, str]:
|
create_pr: bool = False,
|
||||||
|
) -> UploadResult:
|
||||||
api = HfApi(token=token)
|
api = HfApi(token=token)
|
||||||
uploaded: dict[str, str] = {}
|
operations = [
|
||||||
for target in targets:
|
CommitOperationAdd(path_in_repo=target.path_in_repo, path_or_fileobj=str(target.local_path))
|
||||||
api.upload_file(
|
for target in targets
|
||||||
path_or_fileobj=str(target.local_path),
|
]
|
||||||
path_in_repo=target.path_in_repo,
|
commit = api.create_commit(
|
||||||
repo_id=repo_id,
|
repo_id=repo_id,
|
||||||
repo_type=repo_type,
|
repo_type=repo_type,
|
||||||
commit_message=commit_message or f"Upload {target.path_in_repo}",
|
operations=operations,
|
||||||
|
commit_message=commit_message or f"Upload {len(targets)} profiling artifacts",
|
||||||
|
revision="main",
|
||||||
|
create_pr=create_pr,
|
||||||
|
)
|
||||||
|
revision = "main"
|
||||||
|
pr_num = parse_discussion_num(commit.pr_url)
|
||||||
|
if create_pr and pr_num is not None:
|
||||||
|
revision = f"refs/pr/{pr_num}"
|
||||||
|
uploaded = {
|
||||||
|
target.path_in_repo: make_hub_file_url(
|
||||||
|
repo_id, target.path_in_repo, repo_type=repo_type, revision=revision
|
||||||
)
|
)
|
||||||
uploaded[target.path_in_repo] = make_hub_file_url(repo_id, target.path_in_repo, repo_type=repo_type)
|
for target in targets
|
||||||
return uploaded
|
}
|
||||||
|
return UploadResult(uploaded_paths=uploaded, pr_url=commit.pr_url)
|
||||||
|
|
||||||
|
|
||||||
def normalize_repo_id(repo: str, hub_org: str) -> str:
|
def normalize_repo_id(repo: str, hub_org: str) -> str:
|
||||||
@@ -205,13 +237,14 @@ def upload_profile_run(
|
|||||||
row_path: Path,
|
row_path: Path,
|
||||||
row_path_in_repo: str,
|
row_path_in_repo: str,
|
||||||
artifact_targets: list[UploadTarget],
|
artifact_targets: list[UploadTarget],
|
||||||
) -> dict[str, str]:
|
create_pr: bool = False,
|
||||||
|
) -> UploadResult:
|
||||||
return upload_targets(
|
return upload_targets(
|
||||||
repo_id=repo_id,
|
repo_id=repo_id,
|
||||||
targets=[*artifact_targets, UploadTarget(local_path=row_path, path_in_repo=row_path_in_repo)],
|
targets=[*artifact_targets, UploadTarget(local_path=row_path, path_in_repo=row_path_in_repo)],
|
||||||
repo_type="dataset",
|
repo_type="dataset",
|
||||||
private=False,
|
|
||||||
commit_message=f"Add model profiling row {row_path_in_repo}",
|
commit_message=f"Add model profiling row {row_path_in_repo}",
|
||||||
|
create_pr=create_pr,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -221,7 +254,12 @@ def main() -> int:
|
|||||||
selected = get_selected_names(args.policies, specs)
|
selected = get_selected_names(args.policies, specs)
|
||||||
args.output_dir.mkdir(parents=True, exist_ok=True)
|
args.output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
repo_id = normalize_repo_id(args.results_repo, args.hub_org)
|
repo_id = normalize_repo_id(args.results_repo, args.hub_org)
|
||||||
git_commit = args.git_commit or subprocess.check_output(["git", "rev-parse", "HEAD"], text=True).strip()
|
git_executable = shutil.which("git")
|
||||||
|
if not git_executable:
|
||||||
|
raise RuntimeError("git executable not found in PATH")
|
||||||
|
git_commit = (
|
||||||
|
args.git_commit or subprocess.check_output([git_executable, "rev-parse", "HEAD"], text=True).strip()
|
||||||
|
)
|
||||||
pr_number = int(args.pr_number) if str(args.pr_number).strip() else None
|
pr_number = int(args.pr_number) if str(args.pr_number).strip() else None
|
||||||
|
|
||||||
for policy_name in selected:
|
for policy_name in selected:
|
||||||
@@ -277,13 +315,16 @@ def main() -> int:
|
|||||||
row_path.write_text(json.dumps(row, indent=2, sort_keys=True))
|
row_path.write_text(json.dumps(row, indent=2, sort_keys=True))
|
||||||
|
|
||||||
if args.publish:
|
if args.publish:
|
||||||
uploaded_paths = upload_profile_run(
|
upload_result = upload_profile_run(
|
||||||
repo_id=repo_id,
|
repo_id=repo_id,
|
||||||
row_path=row_path,
|
row_path=row_path,
|
||||||
row_path_in_repo=row_path_in_repo,
|
row_path_in_repo=row_path_in_repo,
|
||||||
artifact_targets=artifact_targets,
|
artifact_targets=artifact_targets,
|
||||||
|
create_pr=pr_number is not None,
|
||||||
)
|
)
|
||||||
row["uploaded_paths"] = uploaded_paths
|
row["uploaded_paths"] = upload_result.uploaded_paths
|
||||||
|
row["publish_pr_url"] = upload_result.pr_url
|
||||||
|
row["publish_pr_number"] = parse_discussion_num(upload_result.pr_url)
|
||||||
row_path.write_text(json.dumps(row, indent=2, sort_keys=True))
|
row_path.write_text(json.dumps(row, indent=2, sort_keys=True))
|
||||||
|
|
||||||
print(json.dumps(row, indent=2, sort_keys=True))
|
print(json.dumps(row, indent=2, sort_keys=True))
|
||||||
|
|||||||
@@ -110,6 +110,43 @@ def test_build_artifact_index_collects_cprofile_tables_and_traces(tmp_path):
|
|||||||
assert len(targets) == 7
|
assert len(targets) == 7
|
||||||
|
|
||||||
|
|
||||||
|
def test_upload_targets_batches_preview_publish_into_single_hf_pr(monkeypatch, tmp_path):
|
||||||
|
module = _import_model_profiling_script()
|
||||||
|
local_path = tmp_path / "profiling_row.json"
|
||||||
|
local_path.write_text("{}")
|
||||||
|
captured: dict[str, object] = {}
|
||||||
|
|
||||||
|
class _FakeCommit:
|
||||||
|
pr_url = "https://huggingface.co/datasets/lerobot/model-profiling-history/discussions/42"
|
||||||
|
|
||||||
|
class _FakeApi:
|
||||||
|
def __init__(self, token=None):
|
||||||
|
captured["token"] = token
|
||||||
|
|
||||||
|
def create_commit(self, **kwargs):
|
||||||
|
captured.update(kwargs)
|
||||||
|
return _FakeCommit()
|
||||||
|
|
||||||
|
monkeypatch.setattr(module, "HfApi", _FakeApi)
|
||||||
|
|
||||||
|
result = module.upload_targets(
|
||||||
|
repo_id="lerobot/model-profiling-history",
|
||||||
|
targets=[module.UploadTarget(local_path=local_path, path_in_repo="rows/act/run.json")],
|
||||||
|
create_pr=True,
|
||||||
|
token="hf_test_token",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert captured["repo_id"] == "lerobot/model-profiling-history"
|
||||||
|
assert captured["repo_type"] == "dataset"
|
||||||
|
assert captured["revision"] == "main"
|
||||||
|
assert captured["create_pr"] is True
|
||||||
|
operations = captured["operations"]
|
||||||
|
assert len(operations) == 1
|
||||||
|
assert operations[0].path_in_repo == "rows/act/run.json"
|
||||||
|
assert result.pr_url == _FakeCommit.pr_url
|
||||||
|
assert result.uploaded_paths["rows/act/run.json"].endswith("/resolve/refs/pr/42/rows/act/run.json")
|
||||||
|
|
||||||
|
|
||||||
def test_model_profiling_main_smoke_writes_row(monkeypatch, tmp_path):
|
def test_model_profiling_main_smoke_writes_row(monkeypatch, tmp_path):
|
||||||
module = _import_model_profiling_script()
|
module = _import_model_profiling_script()
|
||||||
|
|
||||||
@@ -192,6 +229,20 @@ def test_model_profiling_main_smoke_writes_row(monkeypatch, tmp_path):
|
|||||||
assert "policy_setup" in row["artifact_paths"]["cprofile_summaries"]
|
assert "policy_setup" in row["artifact_paths"]["cprofile_summaries"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_discussion_num_handles_hf_discussion_urls():
|
||||||
|
module = _import_model_profiling_script()
|
||||||
|
|
||||||
|
assert (
|
||||||
|
module.parse_discussion_num(
|
||||||
|
"https://huggingface.co/datasets/lerobot/model-profiling-history/discussions/42"
|
||||||
|
)
|
||||||
|
== 42
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
module.parse_discussion_num("https://huggingface.co/datasets/lerobot/model-profiling-history") is None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_deterministic_forward_artifacts_preserve_policy_mode(tmp_path):
|
def test_deterministic_forward_artifacts_preserve_policy_mode(tmp_path):
|
||||||
from lerobot.utils.profiling_utils import write_deterministic_forward_artifacts
|
from lerobot.utils.profiling_utils import write_deterministic_forward_artifacts
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user