From dec1c6ddc20c417ea2e28bbc6f74d135c353ad35 Mon Sep 17 00:00:00 2001 From: Nicolas Rabault Date: Wed, 24 Jun 2026 11:03:04 +0200 Subject: [PATCH] refactor(train): use module-level HfApi import in push_checkpoint_to_hub huggingface_hub is a core dependency; the in-function import was unnecessary. Move HfApi to a module-level import and point the test monkeypatches at lerobot.common.train_utils.HfApi. --- src/lerobot/common/train_utils.py | 3 +-- tests/utils/test_train_utils.py | 8 ++------ 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/src/lerobot/common/train_utils.py b/src/lerobot/common/train_utils.py index 97488774e..5b5646ed3 100644 --- a/src/lerobot/common/train_utils.py +++ b/src/lerobot/common/train_utils.py @@ -15,6 +15,7 @@ # limitations under the License. from pathlib import Path +from huggingface_hub import HfApi from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler @@ -297,8 +298,6 @@ def push_checkpoint_to_hub( timed-out or crashed run still leaves recoverable checkpoints on the Hub. The model repo is created idempotently. """ - from huggingface_hub import HfApi - api = HfApi() api.create_repo(repo_id=repo_id, repo_type="model", private=private, exist_ok=True) api.upload_folder( diff --git a/tests/utils/test_train_utils.py b/tests/utils/test_train_utils.py index be3f19231..2d8fc9e6c 100644 --- a/tests/utils/test_train_utils.py +++ b/tests/utils/test_train_utils.py @@ -155,12 +155,10 @@ def test_load_training_state_skip_optimizer(tmp_path, optimizer, scheduler): def test_push_checkpoint_to_hub_creates_repo_and_uploads(tmp_path, monkeypatch): - import huggingface_hub - ckpt = tmp_path / "010000" (ckpt / "pretrained_model").mkdir(parents=True) api = MagicMock() - monkeypatch.setattr(huggingface_hub, "HfApi", lambda *a, **k: api) + monkeypatch.setattr("lerobot.common.train_utils.HfApi", lambda *a, **k: api) push_checkpoint_to_hub(ckpt, "user/run", private=True) api.create_repo.assert_called_once() assert api.create_repo.call_args.kwargs["private"] is True @@ -175,12 +173,10 @@ def test_push_checkpoint_to_hub_creates_repo_and_uploads(tmp_path, monkeypatch): def test_push_checkpoint_to_hub_defaults_to_hub_default_visibility(tmp_path, monkeypatch): - import huggingface_hub - ckpt = tmp_path / "010000" (ckpt / "pretrained_model").mkdir(parents=True) api = MagicMock() - monkeypatch.setattr(huggingface_hub, "HfApi", lambda *a, **k: api) + monkeypatch.setattr("lerobot.common.train_utils.HfApi", lambda *a, **k: api) push_checkpoint_to_hub(ckpt, "user/run") api.create_repo.assert_called_once() assert api.create_repo.call_args.kwargs["private"] is None