mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-30 06:37:15 +00:00
refactor(train): use module-level HfApi import in push_checkpoint_to_hub
huggingface_hub is a core dependency; the in-function import was unnecessary. Move HfApi to a module-level import and point the test monkeypatches at lerobot.common.train_utils.HfApi.
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
# limitations under the License.
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
|
||||
@@ -297,8 +298,6 @@ def push_checkpoint_to_hub(
|
||||
timed-out or crashed run still leaves recoverable checkpoints on the Hub.
|
||||
The model repo is created idempotently.
|
||||
"""
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
api = HfApi()
|
||||
api.create_repo(repo_id=repo_id, repo_type="model", private=private, exist_ok=True)
|
||||
api.upload_folder(
|
||||
|
||||
@@ -155,12 +155,10 @@ def test_load_training_state_skip_optimizer(tmp_path, optimizer, scheduler):
|
||||
|
||||
|
||||
def test_push_checkpoint_to_hub_creates_repo_and_uploads(tmp_path, monkeypatch):
|
||||
import huggingface_hub
|
||||
|
||||
ckpt = tmp_path / "010000"
|
||||
(ckpt / "pretrained_model").mkdir(parents=True)
|
||||
api = MagicMock()
|
||||
monkeypatch.setattr(huggingface_hub, "HfApi", lambda *a, **k: api)
|
||||
monkeypatch.setattr("lerobot.common.train_utils.HfApi", lambda *a, **k: api)
|
||||
push_checkpoint_to_hub(ckpt, "user/run", private=True)
|
||||
api.create_repo.assert_called_once()
|
||||
assert api.create_repo.call_args.kwargs["private"] is True
|
||||
@@ -175,12 +173,10 @@ def test_push_checkpoint_to_hub_creates_repo_and_uploads(tmp_path, monkeypatch):
|
||||
|
||||
|
||||
def test_push_checkpoint_to_hub_defaults_to_hub_default_visibility(tmp_path, monkeypatch):
|
||||
import huggingface_hub
|
||||
|
||||
ckpt = tmp_path / "010000"
|
||||
(ckpt / "pretrained_model").mkdir(parents=True)
|
||||
api = MagicMock()
|
||||
monkeypatch.setattr(huggingface_hub, "HfApi", lambda *a, **k: api)
|
||||
monkeypatch.setattr("lerobot.common.train_utils.HfApi", lambda *a, **k: api)
|
||||
push_checkpoint_to_hub(ckpt, "user/run")
|
||||
api.create_repo.assert_called_once()
|
||||
assert api.create_repo.call_args.kwargs["private"] is None
|
||||
|
||||
Reference in New Issue
Block a user