refactor(jobs): hoist huggingface_hub imports to module level in hf.py

huggingface_hub is a core dependency, so the per-function dynamic imports
had no lazy-loading rationale. Move them to a single module-level import
and update test monkeypatch targets to lerobot.jobs.hf.* accordingly.
This commit is contained in:
Nicolas Rabault
2026-06-24 11:00:42 +02:00
parent ed8694c67f
commit 60cbe71857
2 changed files with 27 additions and 36 deletions
+9 -9
View File
@@ -33,7 +33,15 @@ from pathlib import Path
from typing import TYPE_CHECKING
import draccus
from huggingface_hub import get_token
from huggingface_hub import (
HfApi,
create_repo,
fetch_job_logs,
get_token,
inspect_job,
run_job,
upload_file,
)
if TYPE_CHECKING:
from lerobot.configs.train import TrainPipelineConfig
@@ -116,8 +124,6 @@ def build_remote_config_file(cfg, repo_id: str, dest: Path, tags: list[str] | No
def _stage_config_on_hub(cfg, repo_id: str, token: str, tags: list[str] | None = None) -> str:
"""Upload train_config.json to the model repo and return the repo_id for --config_path."""
from huggingface_hub import create_repo, upload_file
create_repo(repo_id, repo_type="model", private=True, exist_ok=True, token=token)
with tempfile.TemporaryDirectory() as tmp:
config_path = build_remote_config_file(cfg, repo_id, Path(tmp) / "train_config.json", tags=tags)
@@ -147,8 +153,6 @@ def _tail_logs(
caller can finish as soon as the trained model lands on the Hub, rather than
waiting out the platform's post-run finalization (which can add ~30s).
"""
from huggingface_hub import fetch_job_logs
printed = 0
while not done.is_set():
try:
@@ -190,8 +194,6 @@ def _poll_until_done(
is reached and `status_holder` is given, records `status_holder["message"]`
(the platform's status message, e.g. "Job timeout").
"""
from huggingface_hub import inspect_job
failures = 0
while not done.is_set():
try:
@@ -219,8 +221,6 @@ def submit_to_hf(cfg: TrainPipelineConfig) -> None:
the job, then either tails logs until completion or detaches immediately.
Ctrl-C detaches without cancelling the remote job.
"""
from huggingface_hub import HfApi, run_job
from lerobot.jobs.dataset import ensure_dataset_available
token = get_token()
+18 -27
View File
@@ -44,7 +44,7 @@ def _fake_inspect(stage_value):
def test_poll_until_done_returns_terminal_stage(monkeypatch):
monkeypatch.setattr("huggingface_hub.inspect_job", _fake_inspect("COMPLETED"))
monkeypatch.setattr("lerobot.jobs.hf.inspect_job", _fake_inspect("COMPLETED"))
done = threading.Event()
assert _poll_until_done("j", done, poll_interval=0.01) == "COMPLETED"
assert done.is_set()
@@ -52,7 +52,7 @@ def test_poll_until_done_returns_terminal_stage(monkeypatch):
def test_poll_until_done_exits_when_done_already_set(monkeypatch):
# Non-terminal forever; with done pre-set the loop must not block and returns None.
monkeypatch.setattr("huggingface_hub.inspect_job", _fake_inspect("RUNNING"))
monkeypatch.setattr("lerobot.jobs.hf.inspect_job", _fake_inspect("RUNNING"))
done = threading.Event()
done.set()
assert _poll_until_done("j", done, poll_interval=0.01) is None
@@ -60,7 +60,7 @@ def test_poll_until_done_exits_when_done_already_set(monkeypatch):
def test_poll_until_done_gives_up_after_repeated_failures(monkeypatch):
monkeypatch.setattr(
"huggingface_hub.inspect_job", lambda job_id: (_ for _ in ()).throw(RuntimeError("boom"))
"lerobot.jobs.hf.inspect_job", lambda job_id: (_ for _ in ()).throw(RuntimeError("boom"))
)
done = threading.Event()
result = _poll_until_done("j", done, poll_interval=0.001, max_failures=3)
@@ -195,8 +195,6 @@ def test_submit_passes_validation_and_submits(monkeypatch):
"""Regression: repo_id must be set BEFORE cfg.validate() or validation raises."""
from unittest.mock import MagicMock
import huggingface_hub
# Patch get_token
monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok")
@@ -208,7 +206,7 @@ def test_submit_passes_validation_and_submits(monkeypatch):
def whoami(self, token=None):
return {"name": "alice"}
monkeypatch.setattr(huggingface_hub, "HfApi", FakeHfApi)
monkeypatch.setattr("lerobot.jobs.hf.HfApi", FakeHfApi)
# ensure_dataset_available returns None; patch it out so no Hub access happens
# (imported inside submit_to_hf via `from lerobot.jobs.dataset import ensure_dataset_available`).
@@ -229,7 +227,7 @@ def test_submit_passes_validation_and_submits(monkeypatch):
run_job_calls.append(kwargs)
return fake_job
monkeypatch.setattr(huggingface_hub, "run_job", fake_run_job)
monkeypatch.setattr("lerobot.jobs.hf.run_job", fake_run_job)
cfg = draccus.parse(
TrainPipelineConfig,
@@ -268,8 +266,6 @@ def test_submit_returns_when_job_completes(monkeypatch):
"""Non-detach path must RETURN (not hang) once the job reaches a terminal stage."""
from types import SimpleNamespace
import huggingface_hub
monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok")
class FakeHfApi:
@@ -279,21 +275,21 @@ def test_submit_returns_when_job_completes(monkeypatch):
def whoami(self, token=None):
return {"name": "alice"}
monkeypatch.setattr(huggingface_hub, "HfApi", FakeHfApi)
monkeypatch.setattr("lerobot.jobs.hf.HfApi", FakeHfApi)
monkeypatch.setattr("lerobot.jobs.dataset.ensure_dataset_available", lambda *a, **kw: None)
monkeypatch.setattr(
"lerobot.jobs.hf._stage_config_on_hub", lambda cfg, repo_id, token, tags=None: repo_id
)
monkeypatch.setattr(huggingface_hub, "run_job", lambda **kw: SimpleNamespace(id="job-1", url="http://x"))
monkeypatch.setattr("lerobot.jobs.hf.run_job", lambda **kw: SimpleNamespace(id="job-1", url="http://x"))
# Job is already COMPLETED on the first poll.
monkeypatch.setattr(
"huggingface_hub.inspect_job",
"lerobot.jobs.hf.inspect_job",
lambda job_id: SimpleNamespace(
status=SimpleNamespace(stage=SimpleNamespace(value="COMPLETED"), message=None)
),
)
# Log stream ends immediately.
monkeypatch.setattr("huggingface_hub.fetch_job_logs", lambda job_id, follow=True: iter(()))
monkeypatch.setattr("lerobot.jobs.hf.fetch_job_logs", lambda job_id, follow=True: iter(()))
cfg = draccus.parse(
TrainPipelineConfig,
@@ -309,8 +305,6 @@ def test_submit_returns_on_model_pushed_marker(monkeypatch):
"""Finish when the model-pushed log appears, even if the job stage never flips."""
from types import SimpleNamespace
import huggingface_hub
monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok")
class FakeHfApi:
@@ -320,21 +314,21 @@ def test_submit_returns_on_model_pushed_marker(monkeypatch):
def whoami(self, token=None):
return {"name": "alice"}
monkeypatch.setattr(huggingface_hub, "HfApi", FakeHfApi)
monkeypatch.setattr("lerobot.jobs.hf.HfApi", FakeHfApi)
monkeypatch.setattr("lerobot.jobs.dataset.ensure_dataset_available", lambda *a, **kw: None)
monkeypatch.setattr(
"lerobot.jobs.hf._stage_config_on_hub", lambda cfg, repo_id, token, tags=None: repo_id
)
monkeypatch.setattr(huggingface_hub, "run_job", lambda **kw: SimpleNamespace(id="job-1", url="http://x"))
monkeypatch.setattr("lerobot.jobs.hf.run_job", lambda **kw: SimpleNamespace(id="job-1", url="http://x"))
# Job stays RUNNING forever — only the log marker can end the command.
monkeypatch.setattr(
"huggingface_hub.inspect_job",
"lerobot.jobs.hf.inspect_job",
lambda job_id: SimpleNamespace(
status=SimpleNamespace(stage=SimpleNamespace(value="RUNNING"), message=None)
),
)
pushed_line = "INFO Model pushed to https://huggingface.co/alice/myrun"
monkeypatch.setattr("huggingface_hub.fetch_job_logs", lambda job_id, follow=True: iter([pushed_line]))
monkeypatch.setattr("lerobot.jobs.hf.fetch_job_logs", lambda job_id, follow=True: iter([pushed_line]))
cfg = draccus.parse(
TrainPipelineConfig,
@@ -355,7 +349,6 @@ def test_submit_returns_on_model_pushed_marker(monkeypatch):
def test_submit_raises_when_wandb_enabled_without_key(monkeypatch):
"""wandb.enable with no key reachable anywhere fails fast, before submitting."""
import huggingface_hub
monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok")
@@ -366,7 +359,7 @@ def test_submit_raises_when_wandb_enabled_without_key(monkeypatch):
def whoami(self, token=None):
return {"name": "alice"}
monkeypatch.setattr(huggingface_hub, "HfApi", FakeHfApi)
monkeypatch.setattr("lerobot.jobs.hf.HfApi", FakeHfApi)
monkeypatch.setattr("lerobot.jobs.hf.resolve_wandb_api_key", lambda: None)
cfg = draccus.parse(
@@ -391,8 +384,6 @@ def test_submit_raises_when_job_ends_in_error(monkeypatch):
"""A terminal non-COMPLETED stage with no model-pushed marker must raise with the status."""
from types import SimpleNamespace
import huggingface_hub
monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok")
class FakeHfApi:
@@ -402,21 +393,21 @@ def test_submit_raises_when_job_ends_in_error(monkeypatch):
def whoami(self, token=None):
return {"name": "alice"}
monkeypatch.setattr(huggingface_hub, "HfApi", FakeHfApi)
monkeypatch.setattr("lerobot.jobs.hf.HfApi", FakeHfApi)
monkeypatch.setattr("lerobot.jobs.dataset.ensure_dataset_available", lambda *a, **kw: None)
monkeypatch.setattr(
"lerobot.jobs.hf._stage_config_on_hub", lambda cfg, repo_id, token, tags=None: repo_id
)
monkeypatch.setattr(huggingface_hub, "run_job", lambda **kw: SimpleNamespace(id="job-1", url="http://x"))
monkeypatch.setattr("lerobot.jobs.hf.run_job", lambda **kw: SimpleNamespace(id="job-1", url="http://x"))
# Job fails: a terminal ERROR stage carrying the platform's status message.
monkeypatch.setattr(
"huggingface_hub.inspect_job",
"lerobot.jobs.hf.inspect_job",
lambda job_id: SimpleNamespace(
status=SimpleNamespace(stage=SimpleNamespace(value="ERROR"), message="Job timeout")
),
)
# Logs end without the model-pushed marker.
monkeypatch.setattr("huggingface_hub.fetch_job_logs", lambda job_id, follow=True: iter(()))
monkeypatch.setattr("lerobot.jobs.hf.fetch_job_logs", lambda job_id, follow=True: iter(()))
cfg = draccus.parse(
TrainPipelineConfig,