refactor(jobs): hoist LeRobotDataset import, guard dataset extra at package init

Move the `from lerobot.datasets import LeRobotDataset` import to the top of
dataset.py and relocate the `require_package("datasets", extra="dataset")`
guard to the jobs package __init__, per review feedback.
This commit is contained in:
Nicolas Rabault
2026-06-29 14:06:10 +02:00
parent 27e4ad5a3f
commit 016fc73bd1
3 changed files with 8 additions and 11 deletions
+6
View File
@@ -12,6 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.utils.import_utils import require_package
# LeRobotDataset (imported at module top in dataset.py) pulls in heavy dataset deps;
# guard the optional dependency here so importing this package fails loudly if it's missing.
require_package("datasets", extra="dataset")
from .hf import submit_to_hf
__all__ = ["submit_to_hf"]
+1 -5
View File
@@ -23,8 +23,8 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from lerobot.datasets import LeRobotDataset
from lerobot.utils.constants import HF_LEROBOT_HOME
from lerobot.utils.import_utils import require_package
if TYPE_CHECKING:
from huggingface_hub import HfApi
@@ -49,9 +49,5 @@ def ensure_dataset_available(repo_id: str, *, api: HfApi, tags: list[str] | None
)
print(f"[dataset] '{repo_id}' is local-only; pushing to a PRIVATE Hub repo...")
# Lazy import: LeRobotDataset pulls in heavy dataset deps; defer until actually needed.
require_package("datasets", extra="dataset")
from lerobot.datasets import LeRobotDataset
LeRobotDataset(repo_id).push_to_hub(private=True, tags=tags)
print(f"[dataset] '{repo_id}' uploaded (private). The job will download it by repo_id.")
+1 -6
View File
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from unittest.mock import MagicMock
import pytest
@@ -47,11 +46,7 @@ def test_dataset_local_only_uploads_privately(tmp_path, monkeypatch):
api = _api_with_dataset(False)
mock_ds_cls = MagicMock()
fake_datasets_module = MagicMock()
fake_datasets_module.LeRobotDataset = mock_ds_cls
monkeypatch.setitem(sys.modules, "lerobot.datasets", fake_datasets_module)
# The `datasets` extra isn't installed in the base test env; skip the import guard.
monkeypatch.setattr("lerobot.jobs.dataset.require_package", lambda *a, **k: None)
monkeypatch.setattr("lerobot.jobs.dataset.LeRobotDataset", mock_ds_cls)
assert ensure_dataset_available("user/ds", api=api, tags=["lerobot", "lelab"]) is None