refactor(train): use module-level HfApi import in push_checkpoint_to_hub

huggingface_hub is a core dependency; the in-function import was
unnecessary. Move HfApi to a module-level import and point the test
monkeypatches at lerobot.common.train_utils.HfApi.
This commit is contained in:
Nicolas Rabault
2026-06-24 11:03:04 +02:00
parent 1172d7b3e2
commit dec1c6ddc2
2 changed files with 3 additions and 8 deletions
+1 -2
View File
@@ -15,6 +15,7 @@
# limitations under the License.
from pathlib import Path
from huggingface_hub import HfApi
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
@@ -297,8 +298,6 @@ def push_checkpoint_to_hub(
timed-out or crashed run still leaves recoverable checkpoints on the Hub.
The model repo is created idempotently.
"""
from huggingface_hub import HfApi
api = HfApi()
api.create_repo(repo_id=repo_id, repo_type="model", private=private, exist_ok=True)
api.upload_folder(
+2 -6
View File
@@ -155,12 +155,10 @@ def test_load_training_state_skip_optimizer(tmp_path, optimizer, scheduler):
def test_push_checkpoint_to_hub_creates_repo_and_uploads(tmp_path, monkeypatch):
import huggingface_hub
ckpt = tmp_path / "010000"
(ckpt / "pretrained_model").mkdir(parents=True)
api = MagicMock()
monkeypatch.setattr(huggingface_hub, "HfApi", lambda *a, **k: api)
monkeypatch.setattr("lerobot.common.train_utils.HfApi", lambda *a, **k: api)
push_checkpoint_to_hub(ckpt, "user/run", private=True)
api.create_repo.assert_called_once()
assert api.create_repo.call_args.kwargs["private"] is True
@@ -175,12 +173,10 @@ def test_push_checkpoint_to_hub_creates_repo_and_uploads(tmp_path, monkeypatch):
def test_push_checkpoint_to_hub_defaults_to_hub_default_visibility(tmp_path, monkeypatch):
import huggingface_hub
ckpt = tmp_path / "010000"
(ckpt / "pretrained_model").mkdir(parents=True)
api = MagicMock()
monkeypatch.setattr(huggingface_hub, "HfApi", lambda *a, **k: api)
monkeypatch.setattr("lerobot.common.train_utils.HfApi", lambda *a, **k: api)
push_checkpoint_to_hub(ckpt, "user/run")
api.create_repo.assert_called_once()
assert api.create_repo.call_args.kwargs["private"] is None