mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 08:39:49 +00:00
fix(profiling): keep ci green when hub publish is unauthorized
This commit is contained in:
@@ -24,6 +24,7 @@ import sys
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
|
||||
|
||||
def _import_model_profiling_script():
|
||||
@@ -229,6 +230,66 @@ def test_model_profiling_main_smoke_writes_row(monkeypatch, tmp_path):
|
||||
assert "policy_setup" in row["artifact_paths"]["cprofile_summaries"]
|
||||
|
||||
|
||||
def test_model_profiling_publish_failure_is_recorded_without_failing(monkeypatch, tmp_path):
|
||||
module = _import_model_profiling_script()
|
||||
|
||||
spec_file = tmp_path / "specs.json"
|
||||
spec_file.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"act": {
|
||||
"steps": 1,
|
||||
"train_args": [
|
||||
"--dataset.repo_id=lerobot/pusht",
|
||||
"--dataset.episodes=[0]",
|
||||
"--policy.type=act",
|
||||
"--policy.device=cuda",
|
||||
"--batch_size=4",
|
||||
],
|
||||
}
|
||||
}
|
||||
)
|
||||
)
|
||||
args = argparse.Namespace(
|
||||
spec_file=spec_file,
|
||||
policies=["act"],
|
||||
output_dir=tmp_path / "results",
|
||||
hub_org="lerobot",
|
||||
results_repo="model-profiling-history",
|
||||
publish=True,
|
||||
profile_mode="summary",
|
||||
git_commit="deadbeef",
|
||||
git_ref="codex/model-profiling",
|
||||
pr_number="3389",
|
||||
)
|
||||
|
||||
monkeypatch.setattr(module, "parse_args", lambda: args)
|
||||
|
||||
def _fake_run(cmd, capture_output, text):
|
||||
profile_dir = Path(
|
||||
next(arg.split("=", 1)[1] for arg in cmd if arg.startswith("--profile_output_dir="))
|
||||
)
|
||||
profile_dir.mkdir(parents=True, exist_ok=True)
|
||||
return subprocess.CompletedProcess(cmd, 0, "stdout ok", "")
|
||||
|
||||
monkeypatch.setattr(module.subprocess, "run", _fake_run)
|
||||
|
||||
def _fake_upload_profile_run(**kwargs):
|
||||
response = type("Response", (), {"status_code": 403, "headers": {}, "request": None})()
|
||||
raise HfHubHTTPError("403 Forbidden: Authorization error.", response=response)
|
||||
|
||||
monkeypatch.setattr(module, "upload_profile_run", _fake_upload_profile_run)
|
||||
|
||||
assert module.main() == 0
|
||||
|
||||
row_paths = list((tmp_path / "results").rglob("profiling_row.json"))
|
||||
assert len(row_paths) == 1
|
||||
row = json.loads(row_paths[0].read_text())
|
||||
assert row["status"] == "success"
|
||||
assert row["publish_status"] == "failed"
|
||||
assert "Authorization error" in row["publish_error"]
|
||||
|
||||
|
||||
def test_parse_discussion_num_handles_hf_discussion_urls():
|
||||
module = _import_model_profiling_script()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user