mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 03:59:42 +00:00
fix(profiling): keep ci green when hub publish is unauthorized
This commit is contained in:
@@ -28,6 +28,7 @@ from pathlib import Path
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from huggingface_hub import CommitOperationAdd, HfApi
|
from huggingface_hub import CommitOperationAdd, HfApi
|
||||||
|
from huggingface_hub.errors import HfHubHTTPError
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
@@ -315,16 +316,22 @@ def main() -> int:
|
|||||||
row_path.write_text(json.dumps(row, indent=2, sort_keys=True))
|
row_path.write_text(json.dumps(row, indent=2, sort_keys=True))
|
||||||
|
|
||||||
if args.publish:
|
if args.publish:
|
||||||
upload_result = upload_profile_run(
|
try:
|
||||||
repo_id=repo_id,
|
upload_result = upload_profile_run(
|
||||||
row_path=row_path,
|
repo_id=repo_id,
|
||||||
row_path_in_repo=row_path_in_repo,
|
row_path=row_path,
|
||||||
artifact_targets=artifact_targets,
|
row_path_in_repo=row_path_in_repo,
|
||||||
create_pr=pr_number is not None,
|
artifact_targets=artifact_targets,
|
||||||
)
|
create_pr=pr_number is not None,
|
||||||
row["uploaded_paths"] = upload_result.uploaded_paths
|
)
|
||||||
row["publish_pr_url"] = upload_result.pr_url
|
except HfHubHTTPError as exc:
|
||||||
row["publish_pr_number"] = parse_discussion_num(upload_result.pr_url)
|
row["publish_status"] = "failed"
|
||||||
|
row["publish_error"] = str(exc)
|
||||||
|
else:
|
||||||
|
row["publish_status"] = "success"
|
||||||
|
row["uploaded_paths"] = upload_result.uploaded_paths
|
||||||
|
row["publish_pr_url"] = upload_result.pr_url
|
||||||
|
row["publish_pr_number"] = parse_discussion_num(upload_result.pr_url)
|
||||||
row_path.write_text(json.dumps(row, indent=2, sort_keys=True))
|
row_path.write_text(json.dumps(row, indent=2, sort_keys=True))
|
||||||
|
|
||||||
print(json.dumps(row, indent=2, sort_keys=True))
|
print(json.dumps(row, indent=2, sort_keys=True))
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ import sys
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from huggingface_hub.errors import HfHubHTTPError
|
||||||
|
|
||||||
|
|
||||||
def _import_model_profiling_script():
|
def _import_model_profiling_script():
|
||||||
@@ -229,6 +230,66 @@ def test_model_profiling_main_smoke_writes_row(monkeypatch, tmp_path):
|
|||||||
assert "policy_setup" in row["artifact_paths"]["cprofile_summaries"]
|
assert "policy_setup" in row["artifact_paths"]["cprofile_summaries"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_profiling_publish_failure_is_recorded_without_failing(monkeypatch, tmp_path):
|
||||||
|
module = _import_model_profiling_script()
|
||||||
|
|
||||||
|
spec_file = tmp_path / "specs.json"
|
||||||
|
spec_file.write_text(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"act": {
|
||||||
|
"steps": 1,
|
||||||
|
"train_args": [
|
||||||
|
"--dataset.repo_id=lerobot/pusht",
|
||||||
|
"--dataset.episodes=[0]",
|
||||||
|
"--policy.type=act",
|
||||||
|
"--policy.device=cuda",
|
||||||
|
"--batch_size=4",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
args = argparse.Namespace(
|
||||||
|
spec_file=spec_file,
|
||||||
|
policies=["act"],
|
||||||
|
output_dir=tmp_path / "results",
|
||||||
|
hub_org="lerobot",
|
||||||
|
results_repo="model-profiling-history",
|
||||||
|
publish=True,
|
||||||
|
profile_mode="summary",
|
||||||
|
git_commit="deadbeef",
|
||||||
|
git_ref="codex/model-profiling",
|
||||||
|
pr_number="3389",
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setattr(module, "parse_args", lambda: args)
|
||||||
|
|
||||||
|
def _fake_run(cmd, capture_output, text):
|
||||||
|
profile_dir = Path(
|
||||||
|
next(arg.split("=", 1)[1] for arg in cmd if arg.startswith("--profile_output_dir="))
|
||||||
|
)
|
||||||
|
profile_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
return subprocess.CompletedProcess(cmd, 0, "stdout ok", "")
|
||||||
|
|
||||||
|
monkeypatch.setattr(module.subprocess, "run", _fake_run)
|
||||||
|
|
||||||
|
def _fake_upload_profile_run(**kwargs):
|
||||||
|
response = type("Response", (), {"status_code": 403, "headers": {}, "request": None})()
|
||||||
|
raise HfHubHTTPError("403 Forbidden: Authorization error.", response=response)
|
||||||
|
|
||||||
|
monkeypatch.setattr(module, "upload_profile_run", _fake_upload_profile_run)
|
||||||
|
|
||||||
|
assert module.main() == 0
|
||||||
|
|
||||||
|
row_paths = list((tmp_path / "results").rglob("profiling_row.json"))
|
||||||
|
assert len(row_paths) == 1
|
||||||
|
row = json.loads(row_paths[0].read_text())
|
||||||
|
assert row["status"] == "success"
|
||||||
|
assert row["publish_status"] == "failed"
|
||||||
|
assert "Authorization error" in row["publish_error"]
|
||||||
|
|
||||||
|
|
||||||
def test_parse_discussion_num_handles_hf_discussion_urls():
|
def test_parse_discussion_num_handles_hf_discussion_urls():
|
||||||
module = _import_model_profiling_script()
|
module = _import_model_profiling_script()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user