fix(dataset): use revision-safe Hub cache for downloaded datasets (#3233)

* refactor(dataset): enhance dataset root directory handling and introduce hub cache support

- Updated DatasetConfig and LeRobotDatasetMetadata to clarify root directory behavior and introduce a dedicated hub cache for downloads.
- Refactored LeRobotDataset and StreamingLeRobotDataset to utilize the new hub cache and improve directory management.
- Added tests to ensure correct behavior when using the hub cache and handling different revisions without a specified root directory (see the usage sketch after this list).

* refactor(dataset): improve root directory handling in LeRobotDataset

- Updated LeRobotDataset to store the requested root path separately from the actual root path.
- Adjusted metadata loading to use the requested root, enhancing clarity and consistency in directory management.

* refactor(dataset): minor improvements for hub cache support

* chore(datasets): guard in resume + assertion test
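
For illustration, a minimal usage sketch of the resulting behavior; the repo id, commit hash, and local paths below are placeholders, and downloading requires network access:

```python
from pathlib import Path

from lerobot.datasets.lerobot_dataset import LeRobotDataset

# Without root: Hub downloads resolve through LeRobot's revision-safe snapshot
# cache under $HF_LEROBOT_HOME/hub, so two revisions of the same repo get
# distinct on-disk snapshot directories instead of overwriting each other.
ds_main = LeRobotDataset("your-org/your-dataset", revision="main")
ds_pinned = LeRobotDataset("your-org/your-dataset", revision="0123abc")  # placeholder commit hash
# ds_main.root and ds_pinned.root point at different snapshot directories.

# With root: all files are materialized directly under the given directory, as before.
ds_local = LeRobotDataset("your-org/your-dataset", root=Path("data/my_dataset"))
assert ds_local.root == Path("data/my_dataset")
```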

---------

Co-authored-by: AdilZouitine <adilzouitinegm@gmail.com>
Co-authored-by: mickaelChen <mickael.chen.levinson@gmail.com>
Steven Palma
2026-03-27 22:21:55 +01:00
committed by GitHub
parent 975d89b38d
commit 4e45acca52
8 changed files with 440 additions and 40 deletions
+2 -1
@@ -27,7 +27,8 @@ class DatasetConfig:
# "dataset_index" into the returned item. The index mapping is made according to the order in which the
# datasets are provided.
repo_id: str
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
# Root directory for a concrete local dataset tree (e.g. 'dataset/path'). If None, local datasets are
# looked up under $HF_LEROBOT_HOME/repo_id and Hub downloads use a revision-safe cache under $HF_LEROBOT_HOME/hub.
root: str | None = None
episodes: list[int] | None = None
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
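
A hypothetical config illustrating the two modes described in the comment above (the `DatasetConfig` import path is an assumption):

```python
from lerobot.configs.default import DatasetConfig  # import path assumed for illustration

# root=None: existing local datasets are found under $HF_LEROBOT_HOME/repo_id,
# and anything fetched from the Hub goes through the revision-safe cache
# under $HF_LEROBOT_HOME/hub.
cfg_cached = DatasetConfig(repo_id="your-org/your-dataset")

# Explicit root: the dataset tree lives (or is materialized) directly here.
cfg_local = DatasetConfig(repo_id="your-org/your-dataset", root="data/my_dataset")
```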
+31 -8
@@ -44,11 +44,12 @@ from lerobot.datasets.utils import (
check_version_compatibility,
flatten_dict,
get_safe_version,
has_legacy_hub_download_metadata,
is_valid_version,
update_chunk_file_indices,
)
from lerobot.datasets.video_utils import get_video_info
from lerobot.utils.constants import HF_LEROBOT_HOME
from lerobot.utils.constants import HF_LEROBOT_HOME, HF_LEROBOT_HUB_CACHE
CODEBASE_VERSION = "v3.0"
@@ -77,8 +78,12 @@ class LeRobotDatasetMetadata:
Args:
repo_id: Repository identifier (e.g. ``'lerobot/aloha_sim'``).
root: Local directory for the dataset. Defaults to
``$HF_LEROBOT_HOME/{repo_id}``.
root: Local directory for the dataset. When provided, Hub downloads
are materialized directly into this directory. When omitted,
existing local datasets are still looked up under
``$HF_LEROBOT_HOME/{repo_id}``, but Hub downloads use a
revision-safe snapshot cache under
``$HF_LEROBOT_HOME/hub``.
revision: Git revision (branch, tag, or commit hash). Defaults to
the current codebase version.
force_cache_sync: If ``True``, re-download metadata from the Hub
@@ -88,7 +93,8 @@ class LeRobotDatasetMetadata:
"""
self.repo_id = repo_id
self.revision = revision if revision else CODEBASE_VERSION
self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
self._requested_root = Path(root) if root is not None else None
self.root = self._requested_root if self._requested_root is not None else HF_LEROBOT_HOME / repo_id
self._pq_writer = None
self.latest_episode = None
self._metadata_buffer: list[dict] = []
@@ -96,14 +102,15 @@ class LeRobotDatasetMetadata:
self._finalized = False
try:
if force_cache_sync:
if force_cache_sync or (
self._requested_root is None and has_legacy_hub_download_metadata(self.root)
):
raise FileNotFoundError
self._load_metadata()
except (FileNotFoundError, NotADirectoryError):
if is_valid_version(self.revision):
self.revision = get_safe_version(self.repo_id, self.revision)
(self.root / "meta").mkdir(exist_ok=True, parents=True)
self._pull_from_repo(allow_patterns="meta/")
self._load_metadata()
@@ -178,14 +185,29 @@ class LeRobotDatasetMetadata:
allow_patterns: list[str] | str | None = None,
ignore_patterns: list[str] | str | None = None,
) -> None:
if self._requested_root is None:
self.root = Path(
snapshot_download(
self.repo_id,
repo_type="dataset",
revision=self.revision,
cache_dir=HF_LEROBOT_HUB_CACHE,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
)
)
return
self._requested_root.mkdir(exist_ok=True, parents=True)
snapshot_download(
self.repo_id,
repo_type="dataset",
revision=self.revision,
local_dir=self.root,
local_dir=self._requested_root,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
)
self.root = self._requested_root
@property
def url_root(self) -> str:
@@ -593,7 +615,8 @@ class LeRobotDatasetMetadata:
"""
obj = cls.__new__(cls)
obj.repo_id = repo_id
obj.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
obj._requested_root = Path(root) if root is not None else None
obj.root = obj._requested_root if obj._requested_root is not None else HF_LEROBOT_HOME / repo_id
obj.root.mkdir(parents=True, exist_ok=False)
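
For reference, the two `snapshot_download` modes used in `_pull_from_repo` behave roughly as sketched below (repo id and paths are illustrative; `huggingface_hub` owns the actual snapshot layout):

```python
from pathlib import Path

from huggingface_hub import snapshot_download

# Revision-safe mode (no root requested): cache_dir keeps one immutable
# snapshot per revision and returns the path to that snapshot, e.g.
# <cache_dir>/datasets--org--name/snapshots/<commit_sha>/
snapshot_root = Path(
    snapshot_download(
        "your-org/your-dataset",
        repo_type="dataset",
        revision="main",
        cache_dir=Path("~/.cache/huggingface/lerobot/hub").expanduser(),
        allow_patterns="meta/",
    )
)

# Mirror mode (explicit root): local_dir materializes the files in place, so
# downloading a second revision into the same directory overwrites the first.
snapshot_download(
    "your-org/your-dataset",
    repo_type="dataset",
    revision="main",
    local_dir=Path("data/my_dataset"),
    allow_patterns="meta/",
)
```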
+4 -4
@@ -68,7 +68,7 @@ class DatasetReader:
visual features.
"""
self._meta = meta
self._root = root
self.root = root
self.episodes = episodes
self._tolerance_s = tolerance_s
self._video_backend = video_backend
@@ -125,7 +125,7 @@ class DatasetReader:
def _load_hf_dataset(self) -> datasets.Dataset:
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
features = get_hf_features_from_features(self._meta.features)
hf_dataset = load_nested_dataset(self._root / "data", features=features, episodes=self.episodes)
hf_dataset = load_nested_dataset(self.root / "data", features=features, episodes=self.episodes)
hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset
@@ -150,7 +150,7 @@ class DatasetReader:
if len(self._meta.video_keys) > 0:
for ep_idx in requested_episodes:
for vid_key in self._meta.video_keys:
video_path = self._root / self._meta.get_video_file_path(ep_idx, vid_key)
video_path = self.root / self._meta.get_video_file_path(ep_idx, vid_key)
if not video_path.exists():
return False
@@ -240,7 +240,7 @@ class DatasetReader:
from_timestamp = ep[f"videos/{vid_key}/from_timestamp"]
shifted_query_ts = [from_timestamp + ts for ts in query_ts]
video_path = self._root / self._meta.get_video_file_path(ep_idx, vid_key)
video_path = self.root / self._meta.get_video_file_path(ep_idx, vid_key)
frames = decode_video_frames(video_path, shifted_query_ts, self._tolerance_s, self._video_backend)
item[vid_key] = frames.squeeze(0)
+58 -23
@@ -37,7 +37,7 @@ from lerobot.datasets.video_utils import (
get_safe_default_codec,
resolve_vcodec,
)
from lerobot.utils.constants import HF_LEROBOT_HOME
from lerobot.utils.constants import HF_LEROBOT_HUB_CACHE
logger = logging.getLogger(__name__)
@@ -144,10 +144,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
Args:
repo_id (str): This is the repo id that will be used to fetch the dataset.
root (Path | None, optional): Local directory where the dataset will be downloaded and
stored. If set, all dataset files will be stored directly under this path. If not set, the
dataset files will be stored under $HF_LEROBOT_HOME/repo_id (configurable via the
HF_LEROBOT_HOME environment variable).
root (Path | None, optional): Local directory where the dataset will be read from or downloaded
into. If set, all dataset files are materialized directly under this path. If not set,
existing local datasets are still looked up under ``$HF_LEROBOT_HOME/{repo_id}``, but Hub
downloads use a revision-safe snapshot cache under
``$HF_LEROBOT_HOME/hub``.
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
their episode_index in this list. Defaults to None.
image_transforms (Callable | None, optional): You can pass standard v2 image transforms from
@@ -190,7 +191,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
"""
super().__init__()
self.repo_id = repo_id
self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
self._requested_root = Path(root) if root else None
self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps
self.episodes = episodes
@@ -201,12 +202,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
self._vcodec = resolve_vcodec(vcodec)
self._encoder_threads = encoder_threads
self.root.mkdir(exist_ok=True, parents=True)
if self._requested_root is not None:
self._requested_root.mkdir(exist_ok=True, parents=True)
# Load metadata
# Load metadata (sets self.root once from the resolved metadata root)
self.meta = LeRobotDatasetMetadata(
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
self.repo_id, self._requested_root, self.revision, force_cache_sync=force_cache_sync
)
self.root = self.meta.root
self.revision = self.meta.revision
# Create reader (hf_dataset loaded below)
self.reader = DatasetReader(
@@ -556,14 +560,33 @@ class LeRobotDataset(torch.utils.data.Dataset):
if self.episodes is not None:
# Reader is guaranteed to exist here (created in __init__ before _download)
files = self.reader.get_episodes_file_paths()
snapshot_download(
self.repo_id,
repo_type="dataset",
revision=self.revision,
local_dir=self.root,
allow_patterns=files,
ignore_patterns=ignore_patterns,
)
if self._requested_root is None:
self.meta.root = Path(
snapshot_download(
self.repo_id,
repo_type="dataset",
revision=self.revision,
cache_dir=HF_LEROBOT_HUB_CACHE,
allow_patterns=files,
ignore_patterns=ignore_patterns,
)
)
else:
self._requested_root.mkdir(exist_ok=True, parents=True)
snapshot_download(
self.repo_id,
repo_type="dataset",
revision=self.revision,
local_dir=self._requested_root,
allow_patterns=files,
ignore_patterns=ignore_patterns,
)
self.meta.root = self._requested_root
# Propagate resolved root from metadata (single source of truth)
self.root = self.meta.root
self.reader.root = self.meta.root
# ── Class constructors ────────────────────────────────────────────
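
In other words, `meta.root` stays the single source of truth after a download; a small sketch (placeholder repo id, requires network access):

```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset

dataset = LeRobotDataset("your-org/your-dataset")

# Whether the files came from an explicit root or from the hub cache, the
# dataset and its reader always point at the directory resolved by metadata.
assert dataset.root == dataset.meta.root == dataset.reader.root
```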
@@ -635,6 +658,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
metadata_buffer_size=metadata_buffer_size,
)
obj.repo_id = obj.meta.repo_id
obj._requested_root = obj.meta.root
obj.root = obj.meta.root
obj.revision = None
obj.tolerance_s = tolerance_s
@@ -695,8 +719,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
Args:
repo_id: Repository identifier of the existing dataset.
root: Local directory of the dataset. Defaults to
``$HF_LEROBOT_HOME/{repo_id}``.
root: Local directory of the dataset. When provided, Hub downloads
are materialized directly into this directory. When omitted,
Hub downloads use a revision-safe snapshot cache under
``$HF_LEROBOT_HOME/hub``.
tolerance_s: Timestamp synchronization tolerance in seconds.
revision: Git revision (branch, tag, or commit hash). Defaults to
current codebase version tag.
@@ -716,11 +742,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
Returns:
A :class:`LeRobotDataset` in write mode, ready to append episodes.
"""
if not root:
raise ValueError(
"resume() requires an explicit 'root' directory because it creates a DatasetWriter. "
"Writing into the revision-safe Hub snapshot cache (used when root=None) would corrupt "
"the shared cache. Please provide a local directory path."
)
vcodec = resolve_vcodec(vcodec)
obj = cls.__new__(cls)
obj.repo_id = repo_id
obj.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
obj.root.mkdir(exist_ok=True, parents=True)
obj._requested_root = Path(root)
obj.revision = revision if revision else CODEBASE_VERSION
obj.tolerance_s = tolerance_s
obj.image_transforms = None
@@ -731,10 +762,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj._vcodec = vcodec
obj._encoder_threads = encoder_threads
# Load metadata
if obj._requested_root is not None:
obj._requested_root.mkdir(exist_ok=True, parents=True)
# Load metadata (revision-safe when root is not provided)
obj.meta = LeRobotDatasetMetadata(
obj.repo_id, obj.root, obj.revision, force_cache_sync=force_cache_sync
obj.repo_id, obj._requested_root, obj.revision, force_cache_sync=force_cache_sync
)
obj.root = obj.meta.root
# Reader is lazily created on first access (write-only mode)
obj.reader = None
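
The resulting `resume()` contract, as a sketch (the repo id and root path are placeholders, and an existing dataset recorded under that root is assumed):

```python
import pytest

from lerobot.datasets.lerobot_dataset import LeRobotDataset

# Resuming without an explicit root is rejected: a DatasetWriter must not
# write into the shared revision-safe snapshot cache.
with pytest.raises(ValueError):
    LeRobotDataset.resume("your-org/your-dataset")

# Resuming with a concrete local directory works as before.
dataset = LeRobotDataset.resume("your-org/your-dataset", root="data/my_dataset")
```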
+10 -4
@@ -255,7 +255,9 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
Args:
repo_id (str): This is the repo id that will be used to fetch the dataset.
root (Path | None, optional): Local directory to use for downloading/writing files.
root (Path | None, optional): Local directory to use for local datasets. When omitted, Hub
metadata is resolved through a revision-safe snapshot cache under
``$HF_LEROBOT_HOME/hub``.
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
their episode_index in this list.
image_transforms (Callable | None, optional): Transform to apply to image data.
@@ -271,7 +273,8 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
"""
super().__init__()
self.repo_id = repo_id
self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
self._requested_root = Path(root) if root else None
self.root = self._requested_root if self._requested_root is not None else HF_LEROBOT_HOME / repo_id
self.streaming_from_local = root is not None
self.image_transforms = image_transforms
@@ -288,12 +291,15 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
# We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown)
self.video_decoder_cache = None
self.root.mkdir(exist_ok=True, parents=True)
if self._requested_root is not None:
self.root.mkdir(exist_ok=True, parents=True)
# Load metadata
self.meta = LeRobotDatasetMetadata(
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
self.repo_id, self._requested_root, self.revision, force_cache_sync=force_cache_sync
)
self.root = self.meta.root
self.revision = self.meta.revision
# Check version
check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)
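
Streaming follows the same convention; a sketch with a placeholder repo id (the module path for `StreamingLeRobotDataset` is an assumption):

```python
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset  # module path assumed

# Without root, only the lightweight metadata is resolved through the
# revision-safe hub cache; frames are streamed from the Hub.
streaming_ds = StreamingLeRobotDataset("your-org/your-dataset")

# With root, an existing local dataset tree is streamed from disk.
local_streaming_ds = StreamingLeRobotDataset("your-org/your-dataset", root="data/my_dataset")
```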
+13
@@ -18,6 +18,7 @@ import importlib.resources
import json
import logging
from collections.abc import Iterator
from pathlib import Path
from typing import Any
import datasets
@@ -101,6 +102,18 @@ DEFAULT_FEATURES = {
}
def has_legacy_hub_download_metadata(root: Path) -> bool:
"""Return ``True`` when *root* looks like a legacy Hub ``local_dir`` mirror.
``snapshot_download(local_dir=...)`` stores lightweight metadata under
``<local_dir>/.cache/huggingface/download/``. The presence of this
directory is a reliable indicator that the dataset was downloaded with
the old non-revision-safe ``local_dir`` mode and should be re-fetched
through the snapshot cache instead.
"""
return (root / ".cache" / "huggingface" / "download").exists()
def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) -> tuple[int, int]:
if file_idx == chunks_size - 1:
file_idx = 0
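
A quick, self-contained illustration of the legacy-mirror check added above:

```python
import tempfile
from pathlib import Path

from lerobot.datasets.utils import has_legacy_hub_download_metadata

with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    assert not has_legacy_hub_download_metadata(root)

    # snapshot_download(local_dir=...) leaves this marker directory behind,
    # which is what flags the tree as a legacy non-revision-safe mirror.
    (root / ".cache" / "huggingface" / "download").mkdir(parents=True)
    assert has_legacy_hub_download_metadata(root)
```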
+4
@@ -65,6 +65,10 @@ if "LEROBOT_HOME" in os.environ:
# cache dir
default_cache_path = Path(HF_HOME) / "lerobot"
HF_LEROBOT_HOME = Path(os.getenv("HF_LEROBOT_HOME", default_cache_path)).expanduser()
# LeRobot's own revision-safe Hub cache (NOT the system-wide ~/.cache/huggingface/hub/).
# Used as the ``cache_dir`` argument to ``snapshot_download`` so that different
# dataset revisions are stored in isolated snapshot directories.
HF_LEROBOT_HUB_CACHE = HF_LEROBOT_HOME / "hub"
# calibration dir
default_calibration_path = HF_LEROBOT_HOME / "calibration"
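
With default environment variables, the resulting directories look roughly like this (the snapshot layout itself is managed by `huggingface_hub`, and the commit hash is illustrative):

```python
from lerobot.utils.constants import HF_LEROBOT_HOME, HF_LEROBOT_HUB_CACHE

print(HF_LEROBOT_HOME)       # e.g. ~/.cache/huggingface/lerobot
print(HF_LEROBOT_HUB_CACHE)  # e.g. ~/.cache/huggingface/lerobot/hub
# A downloaded dataset revision then lives under something like:
#   ~/.cache/huggingface/lerobot/hub/datasets--org--name/snapshots/<commit_sha>/
```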
+318
@@ -19,9 +19,15 @@ Tests focus on mode contracts (read-only, write-only, resume), guards,
property delegation, and the full create-record-finalize-read lifecycle.
"""
from pathlib import Path
from unittest.mock import Mock
import pytest
import torch
import lerobot.datasets.dataset_metadata as dataset_metadata_module
import lerobot.datasets.lerobot_dataset as lerobot_dataset_module
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.dataset_reader import DatasetReader
from lerobot.datasets.dataset_writer import DatasetWriter
from lerobot.datasets.lerobot_dataset import LeRobotDataset
@@ -30,12 +36,69 @@ from tests.fixtures.constants import DEFAULT_FPS, DUMMY_REPO_ID
SIMPLE_FEATURES = {
"state": {"dtype": "float32", "shape": (2,), "names": None},
}
SNAPSHOT_MAIN_FEATURES = {
**SIMPLE_FEATURES,
"test": {"dtype": "float32", "shape": (2,), "names": None},
}
def _make_frame(task: str = "Dummy task") -> dict:
return {"task": task, "state": torch.randn(2)}
def _set_default_cache_root(monkeypatch: pytest.MonkeyPatch, cache_root: Path) -> None:
monkeypatch.setattr(dataset_metadata_module, "HF_LEROBOT_HOME", cache_root)
monkeypatch.setattr(dataset_metadata_module, "HF_LEROBOT_HUB_CACHE", cache_root / "hub")
monkeypatch.setattr(lerobot_dataset_module, "HF_LEROBOT_HUB_CACHE", cache_root / "hub")
def _write_dataset_tree(
root: Path,
*,
motor_features: dict[str, dict],
info_factory,
stats_factory,
tasks_factory,
episodes_factory,
hf_dataset_factory,
create_info,
create_stats,
create_tasks,
create_episodes,
create_hf_dataset,
) -> None:
root.mkdir(parents=True, exist_ok=True)
info = info_factory(
total_episodes=1,
total_frames=3,
total_tasks=1,
use_videos=False,
motor_features=motor_features,
camera_features={},
)
tasks = tasks_factory(total_tasks=1)
episodes = episodes_factory(
features=info["features"],
fps=info["fps"],
total_episodes=1,
total_frames=3,
tasks=tasks,
)
stats = stats_factory(features=info["features"])
hf_dataset = hf_dataset_factory(
features=info["features"],
tasks=tasks,
episodes=episodes,
fps=info["fps"],
)
create_info(root, info)
create_stats(root, stats)
create_tasks(root, tasks)
create_episodes(root, episodes)
create_hf_dataset(root, hf_dataset)
# ── Read-only mode (via __init__) ────────────────────────────────────
@@ -75,6 +138,261 @@ def test_len_matches_num_frames(tmp_path, lerobot_dataset_factory):
assert len(dataset) == dataset.num_frames
def test_metadata_without_root_uses_hub_cache_snapshot_download(
tmp_path,
info_factory,
stats_factory,
tasks_factory,
episodes_factory,
hf_dataset_factory,
create_info,
create_stats,
create_tasks,
create_episodes,
create_hf_dataset,
monkeypatch,
):
"""Metadata refresh uses the dedicated Hub cache instead of a shared local_dir mirror."""
repo_id = DUMMY_REPO_ID
cache_root = tmp_path / "lerobot_cache"
snapshot_root = cache_root / "hub" / "datasets--dummy--repo" / "snapshots" / "commit-main"
_write_dataset_tree(
snapshot_root,
motor_features=SNAPSHOT_MAIN_FEATURES,
info_factory=info_factory,
stats_factory=stats_factory,
tasks_factory=tasks_factory,
episodes_factory=episodes_factory,
hf_dataset_factory=hf_dataset_factory,
create_info=create_info,
create_stats=create_stats,
create_tasks=create_tasks,
create_episodes=create_episodes,
create_hf_dataset=create_hf_dataset,
)
_set_default_cache_root(monkeypatch, cache_root)
snapshot_download = Mock(return_value=str(snapshot_root))
monkeypatch.setattr(dataset_metadata_module, "snapshot_download", snapshot_download)
meta = LeRobotDatasetMetadata(repo_id=repo_id, revision="main", force_cache_sync=True)
assert meta.root == snapshot_root
assert snapshot_download.call_count == 1
assert snapshot_download.call_args.args == (repo_id,)
assert snapshot_download.call_args.kwargs == {
"repo_type": "dataset",
"revision": "main",
"cache_dir": cache_root / "hub",
"allow_patterns": "meta/",
"ignore_patterns": None,
}
def test_without_root_reads_different_revisions_from_distinct_snapshot_roots(
tmp_path,
info_factory,
stats_factory,
tasks_factory,
episodes_factory,
hf_dataset_factory,
create_info,
create_stats,
create_tasks,
create_episodes,
create_hf_dataset,
monkeypatch,
):
"""Different revisions resolve to different on-disk snapshot roots."""
repo_id = DUMMY_REPO_ID
old_revision = "b59010db93eb6cc3cf06ef2f7cae1bbe62b726d9"
cache_root = tmp_path / "lerobot_cache"
main_root = cache_root / "hub" / "datasets--dummy--repo" / "snapshots" / "commit-main"
old_root = cache_root / "hub" / "datasets--dummy--repo" / "snapshots" / "commit-old"
_write_dataset_tree(
main_root,
motor_features=SNAPSHOT_MAIN_FEATURES,
info_factory=info_factory,
stats_factory=stats_factory,
tasks_factory=tasks_factory,
episodes_factory=episodes_factory,
hf_dataset_factory=hf_dataset_factory,
create_info=create_info,
create_stats=create_stats,
create_tasks=create_tasks,
create_episodes=create_episodes,
create_hf_dataset=create_hf_dataset,
)
_write_dataset_tree(
old_root,
motor_features=SIMPLE_FEATURES,
info_factory=info_factory,
stats_factory=stats_factory,
tasks_factory=tasks_factory,
episodes_factory=episodes_factory,
hf_dataset_factory=hf_dataset_factory,
create_info=create_info,
create_stats=create_stats,
create_tasks=create_tasks,
create_episodes=create_episodes,
create_hf_dataset=create_hf_dataset,
)
_set_default_cache_root(monkeypatch, cache_root)
snapshot_roots = {
"main": main_root,
old_revision: old_root,
}
meta_snapshot_download = Mock(
side_effect=lambda repo_id, **kwargs: str(snapshot_roots[kwargs["revision"]])
)
data_snapshot_download = Mock(
side_effect=lambda repo_id, **kwargs: str(snapshot_roots[kwargs["revision"]])
)
monkeypatch.setattr(dataset_metadata_module, "snapshot_download", meta_snapshot_download)
monkeypatch.setattr(lerobot_dataset_module, "snapshot_download", data_snapshot_download)
main_dataset = LeRobotDataset(
repo_id=repo_id, revision="main", download_videos=False, force_cache_sync=True
)
old_dataset = LeRobotDataset(
repo_id=repo_id, revision=old_revision, download_videos=False, force_cache_sync=True
)
assert main_dataset.root == main_root
assert old_dataset.root == old_root
assert "test" in main_dataset.hf_dataset.column_names
assert "test" not in old_dataset.hf_dataset.column_names
# Metadata downloads use cache_dir, not local_dir
assert meta_snapshot_download.call_count == 2
for download_call in meta_snapshot_download.call_args_list:
assert download_call.kwargs["cache_dir"] == cache_root / "hub"
assert "local_dir" not in download_call.kwargs
# Data downloads also use cache_dir, not local_dir
assert data_snapshot_download.call_count == 2
for download_call in data_snapshot_download.call_args_list:
assert download_call.kwargs["cache_dir"] == cache_root / "hub"
assert "local_dir" not in download_call.kwargs
def test_metadata_without_root_ignores_legacy_local_dir_cache(
tmp_path,
info_factory,
stats_factory,
tasks_factory,
episodes_factory,
hf_dataset_factory,
create_info,
create_stats,
create_tasks,
create_episodes,
create_hf_dataset,
monkeypatch,
):
"""Legacy local-dir mirrors are bypassed in favor of revision-safe snapshots."""
repo_id = DUMMY_REPO_ID
cache_root = tmp_path / "lerobot_cache"
legacy_root = cache_root / repo_id
snapshot_root = cache_root / "hub" / "datasets--dummy--repo" / "snapshots" / "commit-main"
_write_dataset_tree(
legacy_root,
motor_features=SIMPLE_FEATURES,
info_factory=info_factory,
stats_factory=stats_factory,
tasks_factory=tasks_factory,
episodes_factory=episodes_factory,
hf_dataset_factory=hf_dataset_factory,
create_info=create_info,
create_stats=create_stats,
create_tasks=create_tasks,
create_episodes=create_episodes,
create_hf_dataset=create_hf_dataset,
)
(legacy_root / ".cache" / "huggingface" / "download").mkdir(parents=True, exist_ok=True)
_write_dataset_tree(
snapshot_root,
motor_features=SNAPSHOT_MAIN_FEATURES,
info_factory=info_factory,
stats_factory=stats_factory,
tasks_factory=tasks_factory,
episodes_factory=episodes_factory,
hf_dataset_factory=hf_dataset_factory,
create_info=create_info,
create_stats=create_stats,
create_tasks=create_tasks,
create_episodes=create_episodes,
create_hf_dataset=create_hf_dataset,
)
_set_default_cache_root(monkeypatch, cache_root)
snapshot_download = Mock(return_value=str(snapshot_root))
monkeypatch.setattr(dataset_metadata_module, "snapshot_download", snapshot_download)
meta = LeRobotDatasetMetadata(repo_id=repo_id, revision="main")
assert meta.root == snapshot_root
assert "test" in meta.features
assert snapshot_download.call_count == 1
def test_download_without_root_uses_hub_cache(
tmp_path,
info_factory,
stats_factory,
tasks_factory,
episodes_factory,
hf_dataset_factory,
create_info,
create_stats,
create_tasks,
create_episodes,
create_hf_dataset,
monkeypatch,
):
"""LeRobotDataset._download() uses cache_dir (not local_dir) when root is not provided."""
repo_id = DUMMY_REPO_ID
cache_root = tmp_path / "lerobot_cache"
snapshot_root = cache_root / "hub" / "datasets--dummy--repo" / "snapshots" / "commit-main"
# Pre-populate snapshot directory so metadata loads succeed, but leave
# data absent so that _download() is triggered.
_write_dataset_tree(
snapshot_root,
motor_features=SIMPLE_FEATURES,
info_factory=info_factory,
stats_factory=stats_factory,
tasks_factory=tasks_factory,
episodes_factory=episodes_factory,
hf_dataset_factory=hf_dataset_factory,
create_info=create_info,
create_stats=create_stats,
create_tasks=create_tasks,
create_episodes=create_episodes,
create_hf_dataset=create_hf_dataset,
)
_set_default_cache_root(monkeypatch, cache_root)
meta_snapshot_download = Mock(return_value=str(snapshot_root))
monkeypatch.setattr(dataset_metadata_module, "snapshot_download", meta_snapshot_download)
# Mock the data snapshot_download to return the same root (data already
# exists there from _write_dataset_tree).
data_snapshot_download = Mock(return_value=str(snapshot_root))
monkeypatch.setattr(lerobot_dataset_module, "snapshot_download", data_snapshot_download)
LeRobotDataset(repo_id=repo_id, revision="main", force_cache_sync=True)
# _download() should have called snapshot_download with cache_dir
assert data_snapshot_download.call_count == 1
call_kwargs = data_snapshot_download.call_args.kwargs
assert call_kwargs["cache_dir"] == cache_root / "hub"
assert "local_dir" not in call_kwargs
# ── Write-only mode (via create()) ──────────────────────────────────