From 016fc73bd1041f938620356e7138d9556995acac Mon Sep 17 00:00:00 2001 From: Nicolas Rabault Date: Mon, 29 Jun 2026 14:06:10 +0200 Subject: [PATCH] refactor(jobs): hoist LeRobotDataset import, guard dataset extra at package init Move the `from lerobot.datasets import LeRobotDataset` import to the top of dataset.py and relocate the `require_package("datasets", extra="dataset")` guard to the jobs package __init__, per review feedback. --- src/lerobot/jobs/__init__.py | 6 ++++++ src/lerobot/jobs/dataset.py | 6 +----- tests/jobs/test_dataset.py | 7 +------ 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/src/lerobot/jobs/__init__.py b/src/lerobot/jobs/__init__.py index b13133752..674b98b85 100644 --- a/src/lerobot/jobs/__init__.py +++ b/src/lerobot/jobs/__init__.py @@ -12,6 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from lerobot.utils.import_utils import require_package + +# LeRobotDataset (imported at module top in dataset.py) pulls in heavy dataset deps; +# guard the optional dependency here so importing this package fails loudly if it's missing. +require_package("datasets", extra="dataset") + from .hf import submit_to_hf __all__ = ["submit_to_hf"] diff --git a/src/lerobot/jobs/dataset.py b/src/lerobot/jobs/dataset.py index cc01fd9fb..497f8445e 100644 --- a/src/lerobot/jobs/dataset.py +++ b/src/lerobot/jobs/dataset.py @@ -23,8 +23,8 @@ from __future__ import annotations from typing import TYPE_CHECKING +from lerobot.datasets import LeRobotDataset from lerobot.utils.constants import HF_LEROBOT_HOME -from lerobot.utils.import_utils import require_package if TYPE_CHECKING: from huggingface_hub import HfApi @@ -49,9 +49,5 @@ def ensure_dataset_available(repo_id: str, *, api: HfApi, tags: list[str] | None ) print(f"[dataset] '{repo_id}' is local-only; pushing to a PRIVATE Hub repo...") - # Lazy import: LeRobotDataset pulls in heavy dataset deps; defer until actually needed. - require_package("datasets", extra="dataset") - from lerobot.datasets import LeRobotDataset - LeRobotDataset(repo_id).push_to_hub(private=True, tags=tags) print(f"[dataset] '{repo_id}' uploaded (private). The job will download it by repo_id.") diff --git a/tests/jobs/test_dataset.py b/tests/jobs/test_dataset.py index 417f2bb1c..d45ac48aa 100644 --- a/tests/jobs/test_dataset.py +++ b/tests/jobs/test_dataset.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import sys from unittest.mock import MagicMock import pytest @@ -47,11 +46,7 @@ def test_dataset_local_only_uploads_privately(tmp_path, monkeypatch): api = _api_with_dataset(False) mock_ds_cls = MagicMock() - fake_datasets_module = MagicMock() - fake_datasets_module.LeRobotDataset = mock_ds_cls - monkeypatch.setitem(sys.modules, "lerobot.datasets", fake_datasets_module) - # The `datasets` extra isn't installed in the base test env; skip the import guard. - monkeypatch.setattr("lerobot.jobs.dataset.require_package", lambda *a, **k: None) + monkeypatch.setattr("lerobot.jobs.dataset.LeRobotDataset", mock_ds_cls) assert ensure_dataset_available("user/ds", api=api, tags=["lerobot", "lelab"]) is None