fix(profiling): keep ci green when hub publish is unauthorized

This commit is contained in:
Pepijn
2026-04-16 13:07:30 +02:00
parent 8d7099cd7d
commit 6d1a5fca02
2 changed files with 78 additions and 10 deletions
+61
View File
@@ -24,6 +24,7 @@ import sys
from pathlib import Path
import torch
from huggingface_hub.errors import HfHubHTTPError
def _import_model_profiling_script():
@@ -229,6 +230,66 @@ def test_model_profiling_main_smoke_writes_row(monkeypatch, tmp_path):
assert "policy_setup" in row["artifact_paths"]["cprofile_summaries"]
def test_model_profiling_publish_failure_is_recorded_without_failing(monkeypatch, tmp_path):
module = _import_model_profiling_script()
spec_file = tmp_path / "specs.json"
spec_file.write_text(
json.dumps(
{
"act": {
"steps": 1,
"train_args": [
"--dataset.repo_id=lerobot/pusht",
"--dataset.episodes=[0]",
"--policy.type=act",
"--policy.device=cuda",
"--batch_size=4",
],
}
}
)
)
args = argparse.Namespace(
spec_file=spec_file,
policies=["act"],
output_dir=tmp_path / "results",
hub_org="lerobot",
results_repo="model-profiling-history",
publish=True,
profile_mode="summary",
git_commit="deadbeef",
git_ref="codex/model-profiling",
pr_number="3389",
)
monkeypatch.setattr(module, "parse_args", lambda: args)
def _fake_run(cmd, capture_output, text):
profile_dir = Path(
next(arg.split("=", 1)[1] for arg in cmd if arg.startswith("--profile_output_dir="))
)
profile_dir.mkdir(parents=True, exist_ok=True)
return subprocess.CompletedProcess(cmd, 0, "stdout ok", "")
monkeypatch.setattr(module.subprocess, "run", _fake_run)
def _fake_upload_profile_run(**kwargs):
response = type("Response", (), {"status_code": 403, "headers": {}, "request": None})()
raise HfHubHTTPError("403 Forbidden: Authorization error.", response=response)
monkeypatch.setattr(module, "upload_profile_run", _fake_upload_profile_run)
assert module.main() == 0
row_paths = list((tmp_path / "results").rglob("profiling_row.json"))
assert len(row_paths) == 1
row = json.loads(row_paths[0].read_text())
assert row["status"] == "success"
assert row["publish_status"] == "failed"
assert "Authorization error" in row["publish_error"]
def test_parse_discussion_num_handles_hf_discussion_urls():
module = _import_model_profiling_script()