mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
fix(dataset): use revision-safe Hub cache for downloaded datasets (#3233)
* refactor(dataset): enhance dataset root directory handling and introduce hub cache support
  - Updated DatasetConfig and LeRobotDatasetMetadata to clarify root directory behavior and introduce a dedicated hub cache for downloads.
  - Refactored LeRobotDataset and StreamingLeRobotDataset to utilize the new hub cache and improve directory management.
  - Added tests to ensure correct behavior when using the hub cache and handling different revisions without a specified root directory.
* refactor(dataset): improve root directory handling in LeRobotDataset
  - Updated LeRobotDataset to store the requested root path separately from the actual root path.
  - Adjusted metadata loading to use the requested root, enhancing clarity and consistency in directory management.
* refactor(dataset): minor improvements for hub cache support
* chore(datasets): guard in resume + assertion test

---------

Co-authored-by: AdilZouitine <adilzouitinegm@gmail.com>
Co-authored-by: mickaelChen <mickael.chen.levinson@gmail.com>
This commit is contained in:
@@ -27,7 +27,8 @@ class DatasetConfig:
|
||||
# "dataset_index" into the returned item. The index mapping is made according to the order in which the
|
||||
# datasets are provided.
|
||||
repo_id: str
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
|
||||
# Root directory for a concrete local dataset tree (e.g. 'dataset/path'). If None, local datasets are
|
||||
# looked up under $HF_LEROBOT_HOME/repo_id and Hub downloads use a revision-safe cache under $HF_LEROBOT_HOME/hub.
|
||||
root: str | None = None
|
||||
episodes: list[int] | None = None
|
||||
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
|
||||
|
||||
@@ -44,11 +44,12 @@ from lerobot.datasets.utils import (
|
||||
check_version_compatibility,
|
||||
flatten_dict,
|
||||
get_safe_version,
|
||||
has_legacy_hub_download_metadata,
|
||||
is_valid_version,
|
||||
update_chunk_file_indices,
|
||||
)
|
||||
from lerobot.datasets.video_utils import get_video_info
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME, HF_LEROBOT_HUB_CACHE
|
||||
|
||||
CODEBASE_VERSION = "v3.0"
|
||||
|
||||
@@ -77,8 +78,12 @@ class LeRobotDatasetMetadata:
|
||||
|
||||
Args:
|
||||
repo_id: Repository identifier (e.g. ``'lerobot/aloha_sim'``).
|
||||
root: Local directory for the dataset. Defaults to
|
||||
``$HF_LEROBOT_HOME/{repo_id}``.
|
||||
root: Local directory for the dataset. When provided, Hub downloads
|
||||
are materialized directly into this directory. When omitted,
|
||||
existing local datasets are still looked up under
|
||||
``$HF_LEROBOT_HOME/{repo_id}``, but Hub downloads use a
|
||||
revision-safe snapshot cache under
|
||||
``$HF_LEROBOT_HOME/hub``.
|
||||
revision: Git revision (branch, tag, or commit hash). Defaults to
|
||||
the current codebase version.
|
||||
force_cache_sync: If ``True``, re-download metadata from the Hub
|
||||
@@ -88,7 +93,8 @@ class LeRobotDatasetMetadata:
|
||||
"""
|
||||
self.repo_id = repo_id
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
|
||||
self._requested_root = Path(root) if root is not None else None
|
||||
self.root = self._requested_root if self._requested_root is not None else HF_LEROBOT_HOME / repo_id
|
||||
self._pq_writer = None
|
||||
self.latest_episode = None
|
||||
self._metadata_buffer: list[dict] = []
|
||||
@@ -96,14 +102,15 @@ class LeRobotDatasetMetadata:
|
||||
self._finalized = False
|
||||
|
||||
try:
|
||||
if force_cache_sync:
|
||||
if force_cache_sync or (
|
||||
self._requested_root is None and has_legacy_hub_download_metadata(self.root)
|
||||
):
|
||||
raise FileNotFoundError
|
||||
self._load_metadata()
|
||||
except (FileNotFoundError, NotADirectoryError):
|
||||
if is_valid_version(self.revision):
|
||||
self.revision = get_safe_version(self.repo_id, self.revision)
|
||||
|
||||
(self.root / "meta").mkdir(exist_ok=True, parents=True)
|
||||
self._pull_from_repo(allow_patterns="meta/")
|
||||
self._load_metadata()
|
||||
|
||||
@@ -178,14 +185,29 @@ class LeRobotDatasetMetadata:
|
||||
allow_patterns: list[str] | str | None = None,
|
||||
ignore_patterns: list[str] | str | None = None,
|
||||
) -> None:
|
||||
if self._requested_root is None:
|
||||
self.root = Path(
|
||||
snapshot_download(
|
||||
self.repo_id,
|
||||
repo_type="dataset",
|
||||
revision=self.revision,
|
||||
cache_dir=HF_LEROBOT_HUB_CACHE,
|
||||
allow_patterns=allow_patterns,
|
||||
ignore_patterns=ignore_patterns,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
self._requested_root.mkdir(exist_ok=True, parents=True)
|
||||
snapshot_download(
|
||||
self.repo_id,
|
||||
repo_type="dataset",
|
||||
revision=self.revision,
|
||||
local_dir=self.root,
|
||||
local_dir=self._requested_root,
|
||||
allow_patterns=allow_patterns,
|
||||
ignore_patterns=ignore_patterns,
|
||||
)
|
||||
self.root = self._requested_root
|
||||
|
||||
@property
|
||||
def url_root(self) -> str:
|
||||
@@ -593,7 +615,8 @@ class LeRobotDatasetMetadata:
|
||||
"""
|
||||
obj = cls.__new__(cls)
|
||||
obj.repo_id = repo_id
|
||||
obj.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
|
||||
obj._requested_root = Path(root) if root is not None else None
|
||||
obj.root = obj._requested_root if obj._requested_root is not None else HF_LEROBOT_HOME / repo_id
|
||||
|
||||
obj.root.mkdir(parents=True, exist_ok=False)
|
||||
|
||||
|
||||
@@ -68,7 +68,7 @@ class DatasetReader:
|
||||
visual features.
|
||||
"""
|
||||
self._meta = meta
|
||||
self._root = root
|
||||
self.root = root
|
||||
self.episodes = episodes
|
||||
self._tolerance_s = tolerance_s
|
||||
self._video_backend = video_backend
|
||||
@@ -125,7 +125,7 @@ class DatasetReader:
|
||||
def _load_hf_dataset(self) -> datasets.Dataset:
|
||||
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
|
||||
features = get_hf_features_from_features(self._meta.features)
|
||||
hf_dataset = load_nested_dataset(self._root / "data", features=features, episodes=self.episodes)
|
||||
hf_dataset = load_nested_dataset(self.root / "data", features=features, episodes=self.episodes)
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
return hf_dataset
|
||||
|
||||
@@ -150,7 +150,7 @@ class DatasetReader:
|
||||
if len(self._meta.video_keys) > 0:
|
||||
for ep_idx in requested_episodes:
|
||||
for vid_key in self._meta.video_keys:
|
||||
video_path = self._root / self._meta.get_video_file_path(ep_idx, vid_key)
|
||||
video_path = self.root / self._meta.get_video_file_path(ep_idx, vid_key)
|
||||
if not video_path.exists():
|
||||
return False
|
||||
|
||||
@@ -240,7 +240,7 @@ class DatasetReader:
|
||||
from_timestamp = ep[f"videos/{vid_key}/from_timestamp"]
|
||||
shifted_query_ts = [from_timestamp + ts for ts in query_ts]
|
||||
|
||||
video_path = self._root / self._meta.get_video_file_path(ep_idx, vid_key)
|
||||
video_path = self.root / self._meta.get_video_file_path(ep_idx, vid_key)
|
||||
frames = decode_video_frames(video_path, shifted_query_ts, self._tolerance_s, self._video_backend)
|
||||
item[vid_key] = frames.squeeze(0)
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ from lerobot.datasets.video_utils import (
|
||||
get_safe_default_codec,
|
||||
resolve_vcodec,
|
||||
)
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
from lerobot.utils.constants import HF_LEROBOT_HUB_CACHE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -144,10 +144,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
Args:
|
||||
repo_id (str): This is the repo id that will be used to fetch the dataset.
|
||||
root (Path | None, optional): Local directory where the dataset will be downloaded and
|
||||
stored. If set, all dataset files will be stored directly under this path. If not set, the
|
||||
dataset files will be stored under $HF_LEROBOT_HOME/repo_id (configurable via the
|
||||
HF_LEROBOT_HOME environment variable).
|
||||
root (Path | None, optional): Local directory where the dataset will be read from or downloaded
|
||||
into. If set, all dataset files are materialized directly under this path. If not set,
|
||||
existing local datasets are still looked up under ``$HF_LEROBOT_HOME/{repo_id}``, but Hub
|
||||
downloads use a revision-safe snapshot cache under
|
||||
``$HF_LEROBOT_HOME/hub``.
|
||||
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
|
||||
their episode_index in this list. Defaults to None.
|
||||
image_transforms (Callable | None, optional): You can pass standard v2 image transforms from
|
||||
@@ -190,7 +191,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
super().__init__()
|
||||
self.repo_id = repo_id
|
||||
self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
|
||||
self._requested_root = Path(root) if root else None
|
||||
self.image_transforms = image_transforms
|
||||
self.delta_timestamps = delta_timestamps
|
||||
self.episodes = episodes
|
||||
@@ -201,12 +202,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self._vcodec = resolve_vcodec(vcodec)
|
||||
self._encoder_threads = encoder_threads
|
||||
|
||||
self.root.mkdir(exist_ok=True, parents=True)
|
||||
if self._requested_root is not None:
|
||||
self._requested_root.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
# Load metadata
|
||||
# Load metadata (sets self.root once from the resolved metadata root)
|
||||
self.meta = LeRobotDatasetMetadata(
|
||||
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
|
||||
self.repo_id, self._requested_root, self.revision, force_cache_sync=force_cache_sync
|
||||
)
|
||||
self.root = self.meta.root
|
||||
self.revision = self.meta.revision
|
||||
|
||||
# Create reader (hf_dataset loaded below)
|
||||
self.reader = DatasetReader(
|
||||
@@ -556,14 +560,33 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
if self.episodes is not None:
|
||||
# Reader is guaranteed to exist here (created in __init__ before _download)
|
||||
files = self.reader.get_episodes_file_paths()
|
||||
snapshot_download(
|
||||
self.repo_id,
|
||||
repo_type="dataset",
|
||||
revision=self.revision,
|
||||
local_dir=self.root,
|
||||
allow_patterns=files,
|
||||
ignore_patterns=ignore_patterns,
|
||||
)
|
||||
|
||||
if self._requested_root is None:
|
||||
self.meta.root = Path(
|
||||
snapshot_download(
|
||||
self.repo_id,
|
||||
repo_type="dataset",
|
||||
revision=self.revision,
|
||||
cache_dir=HF_LEROBOT_HUB_CACHE,
|
||||
allow_patterns=files,
|
||||
ignore_patterns=ignore_patterns,
|
||||
)
|
||||
)
|
||||
else:
|
||||
self._requested_root.mkdir(exist_ok=True, parents=True)
|
||||
snapshot_download(
|
||||
self.repo_id,
|
||||
repo_type="dataset",
|
||||
revision=self.revision,
|
||||
local_dir=self._requested_root,
|
||||
allow_patterns=files,
|
||||
ignore_patterns=ignore_patterns,
|
||||
)
|
||||
self.meta.root = self._requested_root
|
||||
|
||||
# Propagate resolved root from metadata (single source of truth)
|
||||
self.root = self.meta.root
|
||||
self.reader.root = self.meta.root
|
||||
|
||||
# ── Class constructors ────────────────────────────────────────────
|
||||
|
||||
@@ -635,6 +658,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
metadata_buffer_size=metadata_buffer_size,
|
||||
)
|
||||
obj.repo_id = obj.meta.repo_id
|
||||
obj._requested_root = obj.meta.root
|
||||
obj.root = obj.meta.root
|
||||
obj.revision = None
|
||||
obj.tolerance_s = tolerance_s
|
||||
@@ -695,8 +719,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
Args:
|
||||
repo_id: Repository identifier of the existing dataset.
|
||||
root: Local directory of the dataset. Defaults to
|
||||
``$HF_LEROBOT_HOME/{repo_id}``.
|
||||
root: Local directory of the dataset. When provided, Hub downloads
|
||||
are materialized directly into this directory. When omitted,
|
||||
Hub downloads use a revision-safe snapshot cache under
|
||||
``$HF_LEROBOT_HOME/hub``.
|
||||
tolerance_s: Timestamp synchronization tolerance in seconds.
|
||||
revision: Git revision (branch, tag, or commit hash). Defaults to
|
||||
current codebase version tag.
|
||||
@@ -716,11 +742,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
Returns:
|
||||
A :class:`LeRobotDataset` in write mode, ready to append episodes.
|
||||
"""
|
||||
if not root:
|
||||
raise ValueError(
|
||||
"resume() requires an explicit 'root' directory because it creates a DatasetWriter. "
|
||||
"Writing into the revision-safe Hub snapshot cache (used when root=None) would corrupt "
|
||||
"the shared cache. Please provide a local directory path."
|
||||
)
|
||||
vcodec = resolve_vcodec(vcodec)
|
||||
obj = cls.__new__(cls)
|
||||
obj.repo_id = repo_id
|
||||
obj.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
|
||||
obj.root.mkdir(exist_ok=True, parents=True)
|
||||
obj._requested_root = Path(root)
|
||||
obj.revision = revision if revision else CODEBASE_VERSION
|
||||
obj.tolerance_s = tolerance_s
|
||||
obj.image_transforms = None
|
||||
@@ -731,10 +762,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj._vcodec = vcodec
|
||||
obj._encoder_threads = encoder_threads
|
||||
|
||||
# Load metadata
|
||||
if obj._requested_root is not None:
|
||||
obj._requested_root.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
# Load metadata (revision-safe when root is not provided)
|
||||
obj.meta = LeRobotDatasetMetadata(
|
||||
obj.repo_id, obj.root, obj.revision, force_cache_sync=force_cache_sync
|
||||
obj.repo_id, obj._requested_root, obj.revision, force_cache_sync=force_cache_sync
|
||||
)
|
||||
obj.root = obj.meta.root
|
||||
|
||||
# Reader is lazily created on first access (write-only mode)
|
||||
obj.reader = None
|
||||
|
||||
@@ -255,7 +255,9 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
|
||||
Args:
|
||||
repo_id (str): This is the repo id that will be used to fetch the dataset.
|
||||
root (Path | None, optional): Local directory to use for downloading/writing files.
|
||||
root (Path | None, optional): Local directory to use for local datasets. When omitted, Hub
|
||||
metadata is resolved through a revision-safe snapshot cache under
|
||||
``$HF_LEROBOT_HOME/hub``.
|
||||
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
|
||||
their episode_index in this list.
|
||||
image_transforms (Callable | None, optional): Transform to apply to image data.
|
||||
@@ -271,7 +273,8 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
"""
|
||||
super().__init__()
|
||||
self.repo_id = repo_id
|
||||
self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
|
||||
self._requested_root = Path(root) if root else None
|
||||
self.root = self._requested_root if self._requested_root is not None else HF_LEROBOT_HOME / repo_id
|
||||
self.streaming_from_local = root is not None
|
||||
|
||||
self.image_transforms = image_transforms
|
||||
@@ -288,12 +291,15 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
# We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown)
|
||||
self.video_decoder_cache = None
|
||||
|
||||
self.root.mkdir(exist_ok=True, parents=True)
|
||||
if self._requested_root is not None:
|
||||
self.root.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
# Load metadata
|
||||
self.meta = LeRobotDatasetMetadata(
|
||||
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
|
||||
self.repo_id, self._requested_root, self.revision, force_cache_sync=force_cache_sync
|
||||
)
|
||||
self.root = self.meta.root
|
||||
self.revision = self.meta.revision
|
||||
# Check version
|
||||
check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ import importlib.resources
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Iterator
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import datasets
|
||||
@@ -101,6 +102,18 @@ DEFAULT_FEATURES = {
|
||||
}
|
||||
|
||||
|
||||
def has_legacy_hub_download_metadata(root: Path) -> bool:
    """Tell whether *root* carries the marker of a legacy ``local_dir`` download.

    ``snapshot_download(local_dir=...)`` leaves lightweight bookkeeping under
    ``<local_dir>/.cache/huggingface/download/``. Finding that path inside
    *root* is taken as the sign that the dataset was fetched with the old,
    non-revision-safe ``local_dir`` mode and should be re-fetched through the
    snapshot cache instead.

    Args:
        root: Candidate dataset directory to inspect.

    Returns:
        ``True`` if the legacy download-metadata path exists under *root*,
        ``False`` otherwise (including when *root* itself does not exist).
    """
    legacy_marker = root.joinpath(".cache", "huggingface", "download")
    return legacy_marker.exists()
|
||||
|
||||
|
||||
def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) -> tuple[int, int]:
|
||||
if file_idx == chunks_size - 1:
|
||||
file_idx = 0
|
||||
|
||||
@@ -65,6 +65,10 @@ if "LEROBOT_HOME" in os.environ:
|
||||
# cache dir
|
||||
default_cache_path = Path(HF_HOME) / "lerobot"
|
||||
HF_LEROBOT_HOME = Path(os.getenv("HF_LEROBOT_HOME", default_cache_path)).expanduser()
|
||||
# LeRobot's own revision-safe Hub cache (NOT the system-wide ~/.cache/huggingface/hub/).
|
||||
# Used as the ``cache_dir`` argument to ``snapshot_download`` so that different
|
||||
# dataset revisions are stored in isolated snapshot directories.
|
||||
HF_LEROBOT_HUB_CACHE = HF_LEROBOT_HOME / "hub"
|
||||
|
||||
# calibration dir
|
||||
default_calibration_path = HF_LEROBOT_HOME / "calibration"
|
||||
|
||||
@@ -19,9 +19,15 @@ Tests focus on mode contracts (read-only, write-only, resume), guards,
|
||||
property delegation, and the full create-record-finalize-read lifecycle.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import lerobot.datasets.dataset_metadata as dataset_metadata_module
|
||||
import lerobot.datasets.lerobot_dataset as lerobot_dataset_module
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.dataset_reader import DatasetReader
|
||||
from lerobot.datasets.dataset_writer import DatasetWriter
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
@@ -30,12 +36,69 @@ from tests.fixtures.constants import DEFAULT_FPS, DUMMY_REPO_ID
|
||||
# Minimal feature schema: a single 2-element float32 "state" vector.
SIMPLE_FEATURES = {
    "state": {"dtype": "float32", "shape": (2,), "names": None},
}

# Feature schema used for the "main" snapshot: the simple schema plus an
# extra "test" feature.
SNAPSHOT_MAIN_FEATURES = dict(
    SIMPLE_FEATURES,
    test={"dtype": "float32", "shape": (2,), "names": None},
)
|
||||
|
||||
|
||||
def _make_frame(task: str = "Dummy task") -> dict:
    """Build one frame dict holding *task* and a random 2-element ``state`` tensor."""
    state = torch.randn(2)
    frame = {"task": task, "state": state}
    return frame
|
||||
|
||||
|
||||
def _set_default_cache_root(monkeypatch: pytest.MonkeyPatch, cache_root: Path) -> None:
    """Redirect the default LeRobot home and Hub cache constants to *cache_root*.

    Patches ``HF_LEROBOT_HOME`` and ``HF_LEROBOT_HUB_CACHE`` on the metadata
    module, and ``HF_LEROBOT_HUB_CACHE`` on the dataset module, so the Hub
    cache resolves to ``cache_root / "hub"``.
    """
    hub_cache = cache_root / "hub"
    overrides = (
        (dataset_metadata_module, "HF_LEROBOT_HOME", cache_root),
        (dataset_metadata_module, "HF_LEROBOT_HUB_CACHE", hub_cache),
        (lerobot_dataset_module, "HF_LEROBOT_HUB_CACHE", hub_cache),
    )
    for target_module, constant_name, replacement in overrides:
        monkeypatch.setattr(target_module, constant_name, replacement)
|
||||
|
||||
|
||||
def _write_dataset_tree(
    root: Path,
    *,
    motor_features: dict[str, dict],
    info_factory,
    stats_factory,
    tasks_factory,
    episodes_factory,
    hf_dataset_factory,
    create_info,
    create_stats,
    create_tasks,
    create_episodes,
    create_hf_dataset,
) -> None:
    """Materialize a tiny dataset (1 episode, 3 frames, 1 task, no videos) under *root*.

    The ``*_factory`` callables build the in-memory pieces (info, tasks,
    episodes, stats, hf dataset) and the ``create_*`` callables are then
    invoked as ``create_x(root, piece)`` to write each piece.

    Args:
        root: Directory to create (parents included) and populate.
        motor_features: Motor feature schema forwarded to ``info_factory``.
    """
    n_episodes = 1
    n_frames = 3
    n_tasks = 1

    root.mkdir(parents=True, exist_ok=True)

    # Build every piece first; episodes depend on tasks, the hf dataset on both.
    info = info_factory(
        total_episodes=n_episodes,
        total_frames=n_frames,
        total_tasks=n_tasks,
        use_videos=False,
        motor_features=motor_features,
        camera_features={},
    )
    features = info["features"]
    fps = info["fps"]

    tasks = tasks_factory(total_tasks=n_tasks)
    episodes = episodes_factory(
        features=features,
        fps=fps,
        total_episodes=n_episodes,
        total_frames=n_frames,
        tasks=tasks,
    )
    stats = stats_factory(features=features)
    hf_dataset = hf_dataset_factory(
        features=features,
        tasks=tasks,
        episodes=episodes,
        fps=fps,
    )

    # Write the pieces in a fixed order.
    writers_and_payloads = (
        (create_info, info),
        (create_stats, stats),
        (create_tasks, tasks),
        (create_episodes, episodes),
        (create_hf_dataset, hf_dataset),
    )
    for write, payload in writers_and_payloads:
        write(root, payload)
|
||||
|
||||
|
||||
# ── Read-only mode (via __init__) ────────────────────────────────────
|
||||
|
||||
|
||||
@@ -75,6 +138,261 @@ def test_len_matches_num_frames(tmp_path, lerobot_dataset_factory):
|
||||
assert len(dataset) == dataset.num_frames
|
||||
|
||||
|
||||
def test_metadata_without_root_uses_hub_cache_snapshot_download(
    tmp_path,
    info_factory,
    stats_factory,
    tasks_factory,
    episodes_factory,
    hf_dataset_factory,
    create_info,
    create_stats,
    create_tasks,
    create_episodes,
    create_hf_dataset,
    monkeypatch,
):
    """With no root given, a forced metadata sync goes through the dedicated Hub cache.

    ``snapshot_download`` must be called once with ``cache_dir`` pointing at the
    LeRobot Hub cache, and the metadata root must be the returned snapshot path.
    """
    repo_id = DUMMY_REPO_ID
    cache_root = tmp_path / "lerobot_cache"
    hub_cache = cache_root / "hub"
    snapshot_root = hub_cache / "datasets--dummy--repo" / "snapshots" / "commit-main"

    # Factories and writers forwarded unchanged to the tree builder.
    tree_fixtures = {
        "info_factory": info_factory,
        "stats_factory": stats_factory,
        "tasks_factory": tasks_factory,
        "episodes_factory": episodes_factory,
        "hf_dataset_factory": hf_dataset_factory,
        "create_info": create_info,
        "create_stats": create_stats,
        "create_tasks": create_tasks,
        "create_episodes": create_episodes,
        "create_hf_dataset": create_hf_dataset,
    }
    _write_dataset_tree(snapshot_root, motor_features=SNAPSHOT_MAIN_FEATURES, **tree_fixtures)

    _set_default_cache_root(monkeypatch, cache_root)
    fake_download = Mock(return_value=str(snapshot_root))
    monkeypatch.setattr(dataset_metadata_module, "snapshot_download", fake_download)

    meta = LeRobotDatasetMetadata(repo_id=repo_id, revision="main", force_cache_sync=True)

    assert meta.root == snapshot_root
    assert fake_download.call_count == 1

    recorded_call = fake_download.call_args
    assert recorded_call.args == (repo_id,)
    expected_kwargs = {
        "repo_type": "dataset",
        "revision": "main",
        "cache_dir": hub_cache,
        "allow_patterns": "meta/",
        "ignore_patterns": None,
    }
    assert recorded_call.kwargs == expected_kwargs
|
||||
|
||||
|
||||
def test_without_root_reads_different_revisions_from_distinct_snapshot_roots(
    tmp_path,
    info_factory,
    stats_factory,
    tasks_factory,
    episodes_factory,
    hf_dataset_factory,
    create_info,
    create_stats,
    create_tasks,
    create_episodes,
    create_hf_dataset,
    monkeypatch,
):
    """Two revisions of one repo must resolve to two separate on-disk snapshot roots.

    The "main" snapshot carries an extra ``test`` feature, so the loaded
    columns show which snapshot each dataset actually read.
    """
    repo_id = DUMMY_REPO_ID
    old_revision = "b59010db93eb6cc3cf06ef2f7cae1bbe62b726d9"
    cache_root = tmp_path / "lerobot_cache"
    hub_cache = cache_root / "hub"
    snapshots_dir = hub_cache / "datasets--dummy--repo" / "snapshots"
    main_root = snapshots_dir / "commit-main"
    old_root = snapshots_dir / "commit-old"

    # Factories and writers forwarded unchanged to the tree builder.
    tree_fixtures = {
        "info_factory": info_factory,
        "stats_factory": stats_factory,
        "tasks_factory": tasks_factory,
        "episodes_factory": episodes_factory,
        "hf_dataset_factory": hf_dataset_factory,
        "create_info": create_info,
        "create_stats": create_stats,
        "create_tasks": create_tasks,
        "create_episodes": create_episodes,
        "create_hf_dataset": create_hf_dataset,
    }
    _write_dataset_tree(main_root, motor_features=SNAPSHOT_MAIN_FEATURES, **tree_fixtures)
    _write_dataset_tree(old_root, motor_features=SIMPLE_FEATURES, **tree_fixtures)

    _set_default_cache_root(monkeypatch, cache_root)
    snapshot_roots = {"main": main_root, old_revision: old_root}

    def resolve_snapshot(repo_id, **kwargs):
        # Stand-in for snapshot_download: map the requested revision to its root.
        return str(snapshot_roots[kwargs["revision"]])

    meta_snapshot_download = Mock(side_effect=resolve_snapshot)
    data_snapshot_download = Mock(side_effect=resolve_snapshot)
    monkeypatch.setattr(dataset_metadata_module, "snapshot_download", meta_snapshot_download)
    monkeypatch.setattr(lerobot_dataset_module, "snapshot_download", data_snapshot_download)

    main_dataset = LeRobotDataset(
        repo_id=repo_id, revision="main", download_videos=False, force_cache_sync=True
    )
    old_dataset = LeRobotDataset(
        repo_id=repo_id, revision=old_revision, download_videos=False, force_cache_sync=True
    )

    assert main_dataset.root == main_root
    assert old_dataset.root == old_root
    assert "test" in main_dataset.hf_dataset.column_names
    assert "test" not in old_dataset.hf_dataset.column_names

    # Metadata downloads use cache_dir, not local_dir
    assert meta_snapshot_download.call_count == 2
    for recorded in meta_snapshot_download.call_args_list:
        assert recorded.kwargs["cache_dir"] == hub_cache
        assert "local_dir" not in recorded.kwargs

    # Data downloads also use cache_dir, not local_dir
    assert data_snapshot_download.call_count == 2
    for recorded in data_snapshot_download.call_args_list:
        assert recorded.kwargs["cache_dir"] == hub_cache
        assert "local_dir" not in recorded.kwargs
|
||||
|
||||
|
||||
def test_metadata_without_root_ignores_legacy_local_dir_cache(
    tmp_path,
    info_factory,
    stats_factory,
    tasks_factory,
    episodes_factory,
    hf_dataset_factory,
    create_info,
    create_stats,
    create_tasks,
    create_episodes,
    create_hf_dataset,
    monkeypatch,
):
    """A legacy ``local_dir`` mirror is skipped in favor of the revision-safe snapshot.

    The legacy tree under ``<cache_root>/<repo_id>`` is marked with
    ``.cache/huggingface/download`` and lacks the ``test`` feature; the metadata
    must instead come from the snapshot root, which has it.
    """
    repo_id = DUMMY_REPO_ID
    cache_root = tmp_path / "lerobot_cache"
    legacy_root = cache_root / repo_id
    snapshot_root = cache_root / "hub" / "datasets--dummy--repo" / "snapshots" / "commit-main"

    # Factories and writers forwarded unchanged to the tree builder.
    tree_fixtures = {
        "info_factory": info_factory,
        "stats_factory": stats_factory,
        "tasks_factory": tasks_factory,
        "episodes_factory": episodes_factory,
        "hf_dataset_factory": hf_dataset_factory,
        "create_info": create_info,
        "create_stats": create_stats,
        "create_tasks": create_tasks,
        "create_episodes": create_episodes,
        "create_hf_dataset": create_hf_dataset,
    }

    _write_dataset_tree(legacy_root, motor_features=SIMPLE_FEATURES, **tree_fixtures)
    # Marker directory that flags the tree as an old local_dir download.
    legacy_marker = legacy_root / ".cache" / "huggingface" / "download"
    legacy_marker.mkdir(parents=True, exist_ok=True)
    _write_dataset_tree(snapshot_root, motor_features=SNAPSHOT_MAIN_FEATURES, **tree_fixtures)

    _set_default_cache_root(monkeypatch, cache_root)
    fake_download = Mock(return_value=str(snapshot_root))
    monkeypatch.setattr(dataset_metadata_module, "snapshot_download", fake_download)

    meta = LeRobotDatasetMetadata(repo_id=repo_id, revision="main")

    assert meta.root == snapshot_root
    assert "test" in meta.features
    assert fake_download.call_count == 1
|
||||
|
||||
|
||||
def test_download_without_root_uses_hub_cache(
    tmp_path,
    info_factory,
    stats_factory,
    tasks_factory,
    episodes_factory,
    hf_dataset_factory,
    create_info,
    create_stats,
    create_tasks,
    create_episodes,
    create_hf_dataset,
    monkeypatch,
):
    """Without a root, ``LeRobotDataset._download()`` passes ``cache_dir`` and never ``local_dir``."""
    repo_id = DUMMY_REPO_ID
    cache_root = tmp_path / "lerobot_cache"
    hub_cache = cache_root / "hub"
    snapshot_root = hub_cache / "datasets--dummy--repo" / "snapshots" / "commit-main"

    # Pre-populate snapshot directory so metadata loads succeed, but leave
    # data absent so that _download() is triggered.
    tree_fixtures = {
        "info_factory": info_factory,
        "stats_factory": stats_factory,
        "tasks_factory": tasks_factory,
        "episodes_factory": episodes_factory,
        "hf_dataset_factory": hf_dataset_factory,
        "create_info": create_info,
        "create_stats": create_stats,
        "create_tasks": create_tasks,
        "create_episodes": create_episodes,
        "create_hf_dataset": create_hf_dataset,
    }
    _write_dataset_tree(snapshot_root, motor_features=SIMPLE_FEATURES, **tree_fixtures)

    _set_default_cache_root(monkeypatch, cache_root)
    snapshot_path = str(snapshot_root)
    meta_snapshot_download = Mock(return_value=snapshot_path)
    monkeypatch.setattr(dataset_metadata_module, "snapshot_download", meta_snapshot_download)

    # Mock the data snapshot_download to return the same root (data already
    # exists there from _write_dataset_tree).
    data_snapshot_download = Mock(return_value=snapshot_path)
    monkeypatch.setattr(lerobot_dataset_module, "snapshot_download", data_snapshot_download)

    LeRobotDataset(repo_id=repo_id, revision="main", force_cache_sync=True)

    # _download() should have called snapshot_download with cache_dir
    assert data_snapshot_download.call_count == 1
    download_kwargs = data_snapshot_download.call_args.kwargs
    assert download_kwargs["cache_dir"] == hub_cache
    assert "local_dir" not in download_kwargs
|
||||
|
||||
|
||||
# ── Write-only mode (via create()) ──────────────────────────────────
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user