mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-30 14:47:10 +00:00
refactor(jobs): hoist LeRobotDataset import, guard dataset extra at package init
Move the `from lerobot.datasets import LeRobotDataset` import to the top of
dataset.py and relocate the `require_package("datasets", extra="dataset")`
guard to the jobs package __init__, per review feedback.
This commit is contained in:
@@ -12,6 +12,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
# LeRobotDataset (imported at module top in dataset.py) pulls in heavy dataset deps;
|
||||
# guard the optional dependency here so importing this package fails loudly if it's missing.
|
||||
require_package("datasets", extra="dataset")
|
||||
|
||||
from .hf import submit_to_hf
|
||||
|
||||
__all__ = ["submit_to_hf"]
|
||||
|
||||
@@ -23,8 +23,8 @@ from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from huggingface_hub import HfApi
|
||||
@@ -49,9 +49,5 @@ def ensure_dataset_available(repo_id: str, *, api: HfApi, tags: list[str] | None
|
||||
)
|
||||
|
||||
print(f"[dataset] '{repo_id}' is local-only; pushing to a PRIVATE Hub repo...")
|
||||
# Lazy import: LeRobotDataset pulls in heavy dataset deps; defer until actually needed.
|
||||
require_package("datasets", extra="dataset")
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
|
||||
LeRobotDataset(repo_id).push_to_hub(private=True, tags=tags)
|
||||
print(f"[dataset] '{repo_id}' uploaded (private). The job will download it by repo_id.")
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
@@ -47,11 +46,7 @@ def test_dataset_local_only_uploads_privately(tmp_path, monkeypatch):
|
||||
|
||||
api = _api_with_dataset(False)
|
||||
mock_ds_cls = MagicMock()
|
||||
fake_datasets_module = MagicMock()
|
||||
fake_datasets_module.LeRobotDataset = mock_ds_cls
|
||||
monkeypatch.setitem(sys.modules, "lerobot.datasets", fake_datasets_module)
|
||||
# The `datasets` extra isn't installed in the base test env; skip the import guard.
|
||||
monkeypatch.setattr("lerobot.jobs.dataset.require_package", lambda *a, **k: None)
|
||||
monkeypatch.setattr("lerobot.jobs.dataset.LeRobotDataset", mock_ds_cls)
|
||||
|
||||
assert ensure_dataset_available("user/ds", api=api, tags=["lerobot", "lelab"]) is None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user