diff --git a/src/lerobot/jobs/hf.py b/src/lerobot/jobs/hf.py index 2cb5518a1..0a2810190 100644 --- a/src/lerobot/jobs/hf.py +++ b/src/lerobot/jobs/hf.py @@ -206,7 +206,8 @@ def _poll_until_done( try: info = inspect_job(job_id=job_id) failures = 0 - stage = info.status.stage.value + # `stage` is an enum in some huggingface_hub versions and a plain str in others. + stage = getattr(info.status.stage, "value", info.status.stage) if stage in _TERMINAL_STAGES: if status_holder is not None: status_holder["message"] = getattr(info.status, "message", None) diff --git a/tests/jobs/test_hf.py b/tests/jobs/test_hf.py index 20e153364..90ebd9c67 100644 --- a/tests/jobs/test_hf.py +++ b/tests/jobs/test_hf.py @@ -41,12 +41,15 @@ def test_resolve_job_tags_always_includes_lerobot_and_dedups(): assert resolve_job_tags(["lelab", "lerobot", "lelab"]) == ["lerobot", "lelab"] -def _fake_inspect(stage_value): - return lambda job_id: SimpleNamespace(status=SimpleNamespace(stage=SimpleNamespace(value=stage_value))) +def _fake_inspect(stage_value, *, as_enum=True): + # huggingface_hub returns `stage` as an enum (with `.value`) in some versions and a plain str in others. + stage = SimpleNamespace(value=stage_value) if as_enum else stage_value + return lambda job_id: SimpleNamespace(status=SimpleNamespace(stage=stage)) -def test_poll_until_done_returns_terminal_stage(monkeypatch): - monkeypatch.setattr("lerobot.jobs.hf.inspect_job", _fake_inspect("COMPLETED")) +@pytest.mark.parametrize("as_enum", [True, False], ids=["enum_stage", "str_stage"]) +def test_poll_until_done_returns_terminal_stage(monkeypatch, as_enum): + monkeypatch.setattr("lerobot.jobs.hf.inspect_job", _fake_inspect("COMPLETED", as_enum=as_enum)) done = threading.Event() assert _poll_until_done("j", done, poll_interval=0.01) == "COMPLETED" assert done.is_set()