From 60cbe7185758ec734c820e2df58e0f0945153a5e Mon Sep 17 00:00:00 2001 From: Nicolas Rabault Date: Wed, 24 Jun 2026 11:00:42 +0200 Subject: [PATCH] refactor(jobs): hoist huggingface_hub imports to module level in hf.py huggingface_hub is a core dependency, so the per-function dynamic imports had no lazy-loading rationale. Move them to a single module-level import and update test monkeypatch targets to lerobot.jobs.hf.* accordingly. --- src/lerobot/jobs/hf.py | 18 ++++++++--------- tests/jobs/test_hf.py | 45 +++++++++++++++++------------------------- 2 files changed, 27 insertions(+), 36 deletions(-) diff --git a/src/lerobot/jobs/hf.py b/src/lerobot/jobs/hf.py index 035356a76..1e0b0e7c8 100644 --- a/src/lerobot/jobs/hf.py +++ b/src/lerobot/jobs/hf.py @@ -33,7 +33,15 @@ from pathlib import Path from typing import TYPE_CHECKING import draccus -from huggingface_hub import get_token +from huggingface_hub import ( + HfApi, + create_repo, + fetch_job_logs, + get_token, + inspect_job, + run_job, + upload_file, +) if TYPE_CHECKING: from lerobot.configs.train import TrainPipelineConfig @@ -116,8 +124,6 @@ def build_remote_config_file(cfg, repo_id: str, dest: Path, tags: list[str] | No def _stage_config_on_hub(cfg, repo_id: str, token: str, tags: list[str] | None = None) -> str: """Upload train_config.json to the model repo and return the repo_id for --config_path.""" - from huggingface_hub import create_repo, upload_file - create_repo(repo_id, repo_type="model", private=True, exist_ok=True, token=token) with tempfile.TemporaryDirectory() as tmp: config_path = build_remote_config_file(cfg, repo_id, Path(tmp) / "train_config.json", tags=tags) @@ -147,8 +153,6 @@ def _tail_logs( caller can finish as soon as the trained model lands on the Hub, rather than waiting out the platform's post-run finalization (which can add ~30s). """ - from huggingface_hub import fetch_job_logs - printed = 0 while not done.is_set(): try: @@ -190,8 +194,6 @@ def _poll_until_done( is reached and `status_holder` is given, records `status_holder["message"]` (the platform's status message, e.g. "Job timeout"). """ - from huggingface_hub import inspect_job - failures = 0 while not done.is_set(): try: @@ -219,8 +221,6 @@ def submit_to_hf(cfg: TrainPipelineConfig) -> None: the job, then either tails logs until completion or detaches immediately. Ctrl-C detaches without cancelling the remote job. """ - from huggingface_hub import HfApi, run_job - from lerobot.jobs.dataset import ensure_dataset_available token = get_token() diff --git a/tests/jobs/test_hf.py b/tests/jobs/test_hf.py index 818e99c88..f6e17c278 100644 --- a/tests/jobs/test_hf.py +++ b/tests/jobs/test_hf.py @@ -44,7 +44,7 @@ def _fake_inspect(stage_value): def test_poll_until_done_returns_terminal_stage(monkeypatch): - monkeypatch.setattr("huggingface_hub.inspect_job", _fake_inspect("COMPLETED")) + monkeypatch.setattr("lerobot.jobs.hf.inspect_job", _fake_inspect("COMPLETED")) done = threading.Event() assert _poll_until_done("j", done, poll_interval=0.01) == "COMPLETED" assert done.is_set() @@ -52,7 +52,7 @@ def test_poll_until_done_returns_terminal_stage(monkeypatch): def test_poll_until_done_exits_when_done_already_set(monkeypatch): # Non-terminal forever; with done pre-set the loop must not block and returns None. - monkeypatch.setattr("huggingface_hub.inspect_job", _fake_inspect("RUNNING")) + monkeypatch.setattr("lerobot.jobs.hf.inspect_job", _fake_inspect("RUNNING")) done = threading.Event() done.set() assert _poll_until_done("j", done, poll_interval=0.01) is None @@ -60,7 +60,7 @@ def test_poll_until_done_exits_when_done_already_set(monkeypatch): def test_poll_until_done_gives_up_after_repeated_failures(monkeypatch): monkeypatch.setattr( - "huggingface_hub.inspect_job", lambda job_id: (_ for _ in ()).throw(RuntimeError("boom")) + "lerobot.jobs.hf.inspect_job", lambda job_id: (_ for _ in ()).throw(RuntimeError("boom")) ) done = threading.Event() result = _poll_until_done("j", done, poll_interval=0.001, max_failures=3) @@ -195,8 +195,6 @@ def test_submit_passes_validation_and_submits(monkeypatch): """Regression: repo_id must be set BEFORE cfg.validate() or validation raises.""" from unittest.mock import MagicMock - import huggingface_hub - # Patch get_token monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok") @@ -208,7 +206,7 @@ def test_submit_passes_validation_and_submits(monkeypatch): def whoami(self, token=None): return {"name": "alice"} - monkeypatch.setattr(huggingface_hub, "HfApi", FakeHfApi) + monkeypatch.setattr("lerobot.jobs.hf.HfApi", FakeHfApi) # ensure_dataset_available returns None; patch it out so no Hub access happens # (imported inside submit_to_hf via `from lerobot.jobs.dataset import ensure_dataset_available`). @@ -229,7 +227,7 @@ def test_submit_passes_validation_and_submits(monkeypatch): run_job_calls.append(kwargs) return fake_job - monkeypatch.setattr(huggingface_hub, "run_job", fake_run_job) + monkeypatch.setattr("lerobot.jobs.hf.run_job", fake_run_job) cfg = draccus.parse( TrainPipelineConfig, @@ -268,8 +266,6 @@ def test_submit_returns_when_job_completes(monkeypatch): """Non-detach path must RETURN (not hang) once the job reaches a terminal stage.""" from types import SimpleNamespace - import huggingface_hub - monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok") class FakeHfApi: @@ -279,21 +275,21 @@ def test_submit_returns_when_job_completes(monkeypatch): def whoami(self, token=None): return {"name": "alice"} - monkeypatch.setattr(huggingface_hub, "HfApi", FakeHfApi) + monkeypatch.setattr("lerobot.jobs.hf.HfApi", FakeHfApi) monkeypatch.setattr("lerobot.jobs.dataset.ensure_dataset_available", lambda *a, **kw: None) monkeypatch.setattr( "lerobot.jobs.hf._stage_config_on_hub", lambda cfg, repo_id, token, tags=None: repo_id ) - monkeypatch.setattr(huggingface_hub, "run_job", lambda **kw: SimpleNamespace(id="job-1", url="http://x")) + monkeypatch.setattr("lerobot.jobs.hf.run_job", lambda **kw: SimpleNamespace(id="job-1", url="http://x")) # Job is already COMPLETED on the first poll. monkeypatch.setattr( - "huggingface_hub.inspect_job", + "lerobot.jobs.hf.inspect_job", lambda job_id: SimpleNamespace( status=SimpleNamespace(stage=SimpleNamespace(value="COMPLETED"), message=None) ), ) # Log stream ends immediately. - monkeypatch.setattr("huggingface_hub.fetch_job_logs", lambda job_id, follow=True: iter(())) + monkeypatch.setattr("lerobot.jobs.hf.fetch_job_logs", lambda job_id, follow=True: iter(())) cfg = draccus.parse( TrainPipelineConfig, @@ -309,8 +305,6 @@ def test_submit_returns_on_model_pushed_marker(monkeypatch): """Finish when the model-pushed log appears, even if the job stage never flips.""" from types import SimpleNamespace - import huggingface_hub - monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok") class FakeHfApi: @@ -320,21 +314,21 @@ def test_submit_returns_on_model_pushed_marker(monkeypatch): def whoami(self, token=None): return {"name": "alice"} - monkeypatch.setattr(huggingface_hub, "HfApi", FakeHfApi) + monkeypatch.setattr("lerobot.jobs.hf.HfApi", FakeHfApi) monkeypatch.setattr("lerobot.jobs.dataset.ensure_dataset_available", lambda *a, **kw: None) monkeypatch.setattr( "lerobot.jobs.hf._stage_config_on_hub", lambda cfg, repo_id, token, tags=None: repo_id ) - monkeypatch.setattr(huggingface_hub, "run_job", lambda **kw: SimpleNamespace(id="job-1", url="http://x")) + monkeypatch.setattr("lerobot.jobs.hf.run_job", lambda **kw: SimpleNamespace(id="job-1", url="http://x")) # Job stays RUNNING forever — only the log marker can end the command. monkeypatch.setattr( - "huggingface_hub.inspect_job", + "lerobot.jobs.hf.inspect_job", lambda job_id: SimpleNamespace( status=SimpleNamespace(stage=SimpleNamespace(value="RUNNING"), message=None) ), ) pushed_line = "INFO Model pushed to https://huggingface.co/alice/myrun" - monkeypatch.setattr("huggingface_hub.fetch_job_logs", lambda job_id, follow=True: iter([pushed_line])) + monkeypatch.setattr("lerobot.jobs.hf.fetch_job_logs", lambda job_id, follow=True: iter([pushed_line])) cfg = draccus.parse( TrainPipelineConfig, @@ -355,7 +349,6 @@ def test_submit_returns_on_model_pushed_marker(monkeypatch): def test_submit_raises_when_wandb_enabled_without_key(monkeypatch): """wandb.enable with no key reachable anywhere fails fast, before submitting.""" - import huggingface_hub monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok") @@ -366,7 +359,7 @@ def test_submit_raises_when_wandb_enabled_without_key(monkeypatch): def whoami(self, token=None): return {"name": "alice"} - monkeypatch.setattr(huggingface_hub, "HfApi", FakeHfApi) + monkeypatch.setattr("lerobot.jobs.hf.HfApi", FakeHfApi) monkeypatch.setattr("lerobot.jobs.hf.resolve_wandb_api_key", lambda: None) cfg = draccus.parse( @@ -391,8 +384,6 @@ def test_submit_raises_when_job_ends_in_error(monkeypatch): """A terminal non-COMPLETED stage with no model-pushed marker must raise with the status.""" from types import SimpleNamespace - import huggingface_hub - monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok") class FakeHfApi: @@ -402,21 +393,21 @@ def test_submit_raises_when_job_ends_in_error(monkeypatch): def whoami(self, token=None): return {"name": "alice"} - monkeypatch.setattr(huggingface_hub, "HfApi", FakeHfApi) + monkeypatch.setattr("lerobot.jobs.hf.HfApi", FakeHfApi) monkeypatch.setattr("lerobot.jobs.dataset.ensure_dataset_available", lambda *a, **kw: None) monkeypatch.setattr( "lerobot.jobs.hf._stage_config_on_hub", lambda cfg, repo_id, token, tags=None: repo_id ) - monkeypatch.setattr(huggingface_hub, "run_job", lambda **kw: SimpleNamespace(id="job-1", url="http://x")) + monkeypatch.setattr("lerobot.jobs.hf.run_job", lambda **kw: SimpleNamespace(id="job-1", url="http://x")) # Job fails: a terminal ERROR stage carrying the platform's status message. monkeypatch.setattr( - "huggingface_hub.inspect_job", + "lerobot.jobs.hf.inspect_job", lambda job_id: SimpleNamespace( status=SimpleNamespace(stage=SimpleNamespace(value="ERROR"), message="Job timeout") ), ) # Logs end without the model-pushed marker. - monkeypatch.setattr("huggingface_hub.fetch_job_logs", lambda job_id, follow=True: iter(())) + monkeypatch.setattr("lerobot.jobs.hf.fetch_job_logs", lambda job_id, follow=True: iter(())) cfg = draccus.parse( TrainPipelineConfig,