mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-25 04:07:02 +00:00
refactor(jobs): hoist huggingface_hub imports to module level in hf.py
huggingface_hub is a core dependency, so the per-function dynamic imports had no lazy-loading rationale. Move them to a single module-level import and update test monkeypatch targets to lerobot.jobs.hf.* accordingly.
This commit is contained in:
@@ -33,7 +33,15 @@ from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import draccus
|
||||
from huggingface_hub import get_token
|
||||
from huggingface_hub import (
|
||||
HfApi,
|
||||
create_repo,
|
||||
fetch_job_logs,
|
||||
get_token,
|
||||
inspect_job,
|
||||
run_job,
|
||||
upload_file,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
@@ -116,8 +124,6 @@ def build_remote_config_file(cfg, repo_id: str, dest: Path, tags: list[str] | No
|
||||
|
||||
def _stage_config_on_hub(cfg, repo_id: str, token: str, tags: list[str] | None = None) -> str:
|
||||
"""Upload train_config.json to the model repo and return the repo_id for --config_path."""
|
||||
from huggingface_hub import create_repo, upload_file
|
||||
|
||||
create_repo(repo_id, repo_type="model", private=True, exist_ok=True, token=token)
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
config_path = build_remote_config_file(cfg, repo_id, Path(tmp) / "train_config.json", tags=tags)
|
||||
@@ -147,8 +153,6 @@ def _tail_logs(
|
||||
caller can finish as soon as the trained model lands on the Hub, rather than
|
||||
waiting out the platform's post-run finalization (which can add ~30s).
|
||||
"""
|
||||
from huggingface_hub import fetch_job_logs
|
||||
|
||||
printed = 0
|
||||
while not done.is_set():
|
||||
try:
|
||||
@@ -190,8 +194,6 @@ def _poll_until_done(
|
||||
is reached and `status_holder` is given, records `status_holder["message"]`
|
||||
(the platform's status message, e.g. "Job timeout").
|
||||
"""
|
||||
from huggingface_hub import inspect_job
|
||||
|
||||
failures = 0
|
||||
while not done.is_set():
|
||||
try:
|
||||
@@ -219,8 +221,6 @@ def submit_to_hf(cfg: TrainPipelineConfig) -> None:
|
||||
the job, then either tails logs until completion or detaches immediately.
|
||||
Ctrl-C detaches without cancelling the remote job.
|
||||
"""
|
||||
from huggingface_hub import HfApi, run_job
|
||||
|
||||
from lerobot.jobs.dataset import ensure_dataset_available
|
||||
|
||||
token = get_token()
|
||||
|
||||
+18
-27
@@ -44,7 +44,7 @@ def _fake_inspect(stage_value):
|
||||
|
||||
|
||||
def test_poll_until_done_returns_terminal_stage(monkeypatch):
|
||||
monkeypatch.setattr("huggingface_hub.inspect_job", _fake_inspect("COMPLETED"))
|
||||
monkeypatch.setattr("lerobot.jobs.hf.inspect_job", _fake_inspect("COMPLETED"))
|
||||
done = threading.Event()
|
||||
assert _poll_until_done("j", done, poll_interval=0.01) == "COMPLETED"
|
||||
assert done.is_set()
|
||||
@@ -52,7 +52,7 @@ def test_poll_until_done_returns_terminal_stage(monkeypatch):
|
||||
|
||||
def test_poll_until_done_exits_when_done_already_set(monkeypatch):
|
||||
# Non-terminal forever; with done pre-set the loop must not block and returns None.
|
||||
monkeypatch.setattr("huggingface_hub.inspect_job", _fake_inspect("RUNNING"))
|
||||
monkeypatch.setattr("lerobot.jobs.hf.inspect_job", _fake_inspect("RUNNING"))
|
||||
done = threading.Event()
|
||||
done.set()
|
||||
assert _poll_until_done("j", done, poll_interval=0.01) is None
|
||||
@@ -60,7 +60,7 @@ def test_poll_until_done_exits_when_done_already_set(monkeypatch):
|
||||
|
||||
def test_poll_until_done_gives_up_after_repeated_failures(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"huggingface_hub.inspect_job", lambda job_id: (_ for _ in ()).throw(RuntimeError("boom"))
|
||||
"lerobot.jobs.hf.inspect_job", lambda job_id: (_ for _ in ()).throw(RuntimeError("boom"))
|
||||
)
|
||||
done = threading.Event()
|
||||
result = _poll_until_done("j", done, poll_interval=0.001, max_failures=3)
|
||||
@@ -195,8 +195,6 @@ def test_submit_passes_validation_and_submits(monkeypatch):
|
||||
"""Regression: repo_id must be set BEFORE cfg.validate() or validation raises."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import huggingface_hub
|
||||
|
||||
# Patch get_token
|
||||
monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok")
|
||||
|
||||
@@ -208,7 +206,7 @@ def test_submit_passes_validation_and_submits(monkeypatch):
|
||||
def whoami(self, token=None):
|
||||
return {"name": "alice"}
|
||||
|
||||
monkeypatch.setattr(huggingface_hub, "HfApi", FakeHfApi)
|
||||
monkeypatch.setattr("lerobot.jobs.hf.HfApi", FakeHfApi)
|
||||
|
||||
# ensure_dataset_available returns None; patch it out so no Hub access happens
|
||||
# (imported inside submit_to_hf via `from lerobot.jobs.dataset import ensure_dataset_available`).
|
||||
@@ -229,7 +227,7 @@ def test_submit_passes_validation_and_submits(monkeypatch):
|
||||
run_job_calls.append(kwargs)
|
||||
return fake_job
|
||||
|
||||
monkeypatch.setattr(huggingface_hub, "run_job", fake_run_job)
|
||||
monkeypatch.setattr("lerobot.jobs.hf.run_job", fake_run_job)
|
||||
|
||||
cfg = draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
@@ -268,8 +266,6 @@ def test_submit_returns_when_job_completes(monkeypatch):
|
||||
"""Non-detach path must RETURN (not hang) once the job reaches a terminal stage."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
import huggingface_hub
|
||||
|
||||
monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok")
|
||||
|
||||
class FakeHfApi:
|
||||
@@ -279,21 +275,21 @@ def test_submit_returns_when_job_completes(monkeypatch):
|
||||
def whoami(self, token=None):
|
||||
return {"name": "alice"}
|
||||
|
||||
monkeypatch.setattr(huggingface_hub, "HfApi", FakeHfApi)
|
||||
monkeypatch.setattr("lerobot.jobs.hf.HfApi", FakeHfApi)
|
||||
monkeypatch.setattr("lerobot.jobs.dataset.ensure_dataset_available", lambda *a, **kw: None)
|
||||
monkeypatch.setattr(
|
||||
"lerobot.jobs.hf._stage_config_on_hub", lambda cfg, repo_id, token, tags=None: repo_id
|
||||
)
|
||||
monkeypatch.setattr(huggingface_hub, "run_job", lambda **kw: SimpleNamespace(id="job-1", url="http://x"))
|
||||
monkeypatch.setattr("lerobot.jobs.hf.run_job", lambda **kw: SimpleNamespace(id="job-1", url="http://x"))
|
||||
# Job is already COMPLETED on the first poll.
|
||||
monkeypatch.setattr(
|
||||
"huggingface_hub.inspect_job",
|
||||
"lerobot.jobs.hf.inspect_job",
|
||||
lambda job_id: SimpleNamespace(
|
||||
status=SimpleNamespace(stage=SimpleNamespace(value="COMPLETED"), message=None)
|
||||
),
|
||||
)
|
||||
# Log stream ends immediately.
|
||||
monkeypatch.setattr("huggingface_hub.fetch_job_logs", lambda job_id, follow=True: iter(()))
|
||||
monkeypatch.setattr("lerobot.jobs.hf.fetch_job_logs", lambda job_id, follow=True: iter(()))
|
||||
|
||||
cfg = draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
@@ -309,8 +305,6 @@ def test_submit_returns_on_model_pushed_marker(monkeypatch):
|
||||
"""Finish when the model-pushed log appears, even if the job stage never flips."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
import huggingface_hub
|
||||
|
||||
monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok")
|
||||
|
||||
class FakeHfApi:
|
||||
@@ -320,21 +314,21 @@ def test_submit_returns_on_model_pushed_marker(monkeypatch):
|
||||
def whoami(self, token=None):
|
||||
return {"name": "alice"}
|
||||
|
||||
monkeypatch.setattr(huggingface_hub, "HfApi", FakeHfApi)
|
||||
monkeypatch.setattr("lerobot.jobs.hf.HfApi", FakeHfApi)
|
||||
monkeypatch.setattr("lerobot.jobs.dataset.ensure_dataset_available", lambda *a, **kw: None)
|
||||
monkeypatch.setattr(
|
||||
"lerobot.jobs.hf._stage_config_on_hub", lambda cfg, repo_id, token, tags=None: repo_id
|
||||
)
|
||||
monkeypatch.setattr(huggingface_hub, "run_job", lambda **kw: SimpleNamespace(id="job-1", url="http://x"))
|
||||
monkeypatch.setattr("lerobot.jobs.hf.run_job", lambda **kw: SimpleNamespace(id="job-1", url="http://x"))
|
||||
# Job stays RUNNING forever — only the log marker can end the command.
|
||||
monkeypatch.setattr(
|
||||
"huggingface_hub.inspect_job",
|
||||
"lerobot.jobs.hf.inspect_job",
|
||||
lambda job_id: SimpleNamespace(
|
||||
status=SimpleNamespace(stage=SimpleNamespace(value="RUNNING"), message=None)
|
||||
),
|
||||
)
|
||||
pushed_line = "INFO Model pushed to https://huggingface.co/alice/myrun"
|
||||
monkeypatch.setattr("huggingface_hub.fetch_job_logs", lambda job_id, follow=True: iter([pushed_line]))
|
||||
monkeypatch.setattr("lerobot.jobs.hf.fetch_job_logs", lambda job_id, follow=True: iter([pushed_line]))
|
||||
|
||||
cfg = draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
@@ -355,7 +349,6 @@ def test_submit_returns_on_model_pushed_marker(monkeypatch):
|
||||
|
||||
def test_submit_raises_when_wandb_enabled_without_key(monkeypatch):
|
||||
"""wandb.enable with no key reachable anywhere fails fast, before submitting."""
|
||||
import huggingface_hub
|
||||
|
||||
monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok")
|
||||
|
||||
@@ -366,7 +359,7 @@ def test_submit_raises_when_wandb_enabled_without_key(monkeypatch):
|
||||
def whoami(self, token=None):
|
||||
return {"name": "alice"}
|
||||
|
||||
monkeypatch.setattr(huggingface_hub, "HfApi", FakeHfApi)
|
||||
monkeypatch.setattr("lerobot.jobs.hf.HfApi", FakeHfApi)
|
||||
monkeypatch.setattr("lerobot.jobs.hf.resolve_wandb_api_key", lambda: None)
|
||||
|
||||
cfg = draccus.parse(
|
||||
@@ -391,8 +384,6 @@ def test_submit_raises_when_job_ends_in_error(monkeypatch):
|
||||
"""A terminal non-COMPLETED stage with no model-pushed marker must raise with the status."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
import huggingface_hub
|
||||
|
||||
monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok")
|
||||
|
||||
class FakeHfApi:
|
||||
@@ -402,21 +393,21 @@ def test_submit_raises_when_job_ends_in_error(monkeypatch):
|
||||
def whoami(self, token=None):
|
||||
return {"name": "alice"}
|
||||
|
||||
monkeypatch.setattr(huggingface_hub, "HfApi", FakeHfApi)
|
||||
monkeypatch.setattr("lerobot.jobs.hf.HfApi", FakeHfApi)
|
||||
monkeypatch.setattr("lerobot.jobs.dataset.ensure_dataset_available", lambda *a, **kw: None)
|
||||
monkeypatch.setattr(
|
||||
"lerobot.jobs.hf._stage_config_on_hub", lambda cfg, repo_id, token, tags=None: repo_id
|
||||
)
|
||||
monkeypatch.setattr(huggingface_hub, "run_job", lambda **kw: SimpleNamespace(id="job-1", url="http://x"))
|
||||
monkeypatch.setattr("lerobot.jobs.hf.run_job", lambda **kw: SimpleNamespace(id="job-1", url="http://x"))
|
||||
# Job fails: a terminal ERROR stage carrying the platform's status message.
|
||||
monkeypatch.setattr(
|
||||
"huggingface_hub.inspect_job",
|
||||
"lerobot.jobs.hf.inspect_job",
|
||||
lambda job_id: SimpleNamespace(
|
||||
status=SimpleNamespace(stage=SimpleNamespace(value="ERROR"), message="Job timeout")
|
||||
),
|
||||
)
|
||||
# Logs end without the model-pushed marker.
|
||||
monkeypatch.setattr("huggingface_hub.fetch_job_logs", lambda job_id, follow=True: iter(()))
|
||||
monkeypatch.setattr("lerobot.jobs.hf.fetch_job_logs", lambda job_id, follow=True: iter(()))
|
||||
|
||||
cfg = draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
|
||||
Reference in New Issue
Block a user