fix(dataset): use revision-safe Hub cache for downloaded datasets (#3233)

* refactor(dataset): enhance dataset root directory handling and introduce hub cache support

- Updated DatasetConfig and LeRobotDatasetMetadata to clarify root directory behavior and introduce a dedicated hub cache for downloads.
- Refactored LeRobotDataset and StreamingLeRobotDataset to utilize the new hub cache and improve directory management.
- Added tests to ensure correct behavior when using the hub cache and handling different revisions without a specified root directory.

* refactor(dataset): improve root directory handling in LeRobotDataset

- Updated LeRobotDataset to store the requested root path separately from the actual root path.
- Adjusted metadata loading to use the requested root, enhancing clarity and consistency in directory management.

* refactor(dataset): minor improvements for hub cache support

* chore(datasets): guard in resume + assertion test

---------

Co-authored-by: AdilZouitine <adilzouitinegm@gmail.com>
Co-authored-by: mickaelChen <mickael.chen.levinson@gmail.com>
This commit is contained in:
Steven Palma
2026-03-27 22:21:55 +01:00
committed by GitHub
parent 975d89b38d
commit 4e45acca52
8 changed files with 440 additions and 40 deletions
+2 -1
View File
@@ -27,7 +27,8 @@ class DatasetConfig:
# "dataset_index" into the returned item. The index mapping is made according to the order in which the
# datasets are provided.
repo_id: str
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
# Root directory for a concrete local dataset tree (e.g. 'dataset/path'). If None, local datasets are
# looked up under $HF_LEROBOT_HOME/repo_id and Hub downloads use a revision-safe cache under $HF_LEROBOT_HOME/hub.
root: str | None = None
episodes: list[int] | None = None
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
+31 -8
View File
@@ -44,11 +44,12 @@ from lerobot.datasets.utils import (
check_version_compatibility,
flatten_dict,
get_safe_version,
has_legacy_hub_download_metadata,
is_valid_version,
update_chunk_file_indices,
)
from lerobot.datasets.video_utils import get_video_info
from lerobot.utils.constants import HF_LEROBOT_HOME
from lerobot.utils.constants import HF_LEROBOT_HOME, HF_LEROBOT_HUB_CACHE
CODEBASE_VERSION = "v3.0"
@@ -77,8 +78,12 @@ class LeRobotDatasetMetadata:
Args:
repo_id: Repository identifier (e.g. ``'lerobot/aloha_sim'``).
root: Local directory for the dataset. Defaults to
``$HF_LEROBOT_HOME/{repo_id}``.
root: Local directory for the dataset. When provided, Hub downloads
are materialized directly into this directory. When omitted,
existing local datasets are still looked up under
``$HF_LEROBOT_HOME/{repo_id}``, but Hub downloads use a
revision-safe snapshot cache under
``$HF_LEROBOT_HOME/hub``.
revision: Git revision (branch, tag, or commit hash). Defaults to
the current codebase version.
force_cache_sync: If ``True``, re-download metadata from the Hub
@@ -88,7 +93,8 @@ class LeRobotDatasetMetadata:
"""
self.repo_id = repo_id
self.revision = revision if revision else CODEBASE_VERSION
self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
self._requested_root = Path(root) if root is not None else None
self.root = self._requested_root if self._requested_root is not None else HF_LEROBOT_HOME / repo_id
self._pq_writer = None
self.latest_episode = None
self._metadata_buffer: list[dict] = []
@@ -96,14 +102,15 @@ class LeRobotDatasetMetadata:
self._finalized = False
try:
if force_cache_sync:
if force_cache_sync or (
self._requested_root is None and has_legacy_hub_download_metadata(self.root)
):
raise FileNotFoundError
self._load_metadata()
except (FileNotFoundError, NotADirectoryError):
if is_valid_version(self.revision):
self.revision = get_safe_version(self.repo_id, self.revision)
(self.root / "meta").mkdir(exist_ok=True, parents=True)
self._pull_from_repo(allow_patterns="meta/")
self._load_metadata()
@@ -178,14 +185,29 @@ class LeRobotDatasetMetadata:
allow_patterns: list[str] | str | None = None,
ignore_patterns: list[str] | str | None = None,
) -> None:
if self._requested_root is None:
self.root = Path(
snapshot_download(
self.repo_id,
repo_type="dataset",
revision=self.revision,
cache_dir=HF_LEROBOT_HUB_CACHE,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
)
)
return
self._requested_root.mkdir(exist_ok=True, parents=True)
snapshot_download(
self.repo_id,
repo_type="dataset",
revision=self.revision,
local_dir=self.root,
local_dir=self._requested_root,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
)
self.root = self._requested_root
@property
def url_root(self) -> str:
@@ -593,7 +615,8 @@ class LeRobotDatasetMetadata:
"""
obj = cls.__new__(cls)
obj.repo_id = repo_id
obj.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
obj._requested_root = Path(root) if root is not None else None
obj.root = obj._requested_root if obj._requested_root is not None else HF_LEROBOT_HOME / repo_id
obj.root.mkdir(parents=True, exist_ok=False)
+4 -4
View File
@@ -68,7 +68,7 @@ class DatasetReader:
visual features.
"""
self._meta = meta
self._root = root
self.root = root
self.episodes = episodes
self._tolerance_s = tolerance_s
self._video_backend = video_backend
@@ -125,7 +125,7 @@ class DatasetReader:
def _load_hf_dataset(self) -> datasets.Dataset:
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
features = get_hf_features_from_features(self._meta.features)
hf_dataset = load_nested_dataset(self._root / "data", features=features, episodes=self.episodes)
hf_dataset = load_nested_dataset(self.root / "data", features=features, episodes=self.episodes)
hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset
@@ -150,7 +150,7 @@ class DatasetReader:
if len(self._meta.video_keys) > 0:
for ep_idx in requested_episodes:
for vid_key in self._meta.video_keys:
video_path = self._root / self._meta.get_video_file_path(ep_idx, vid_key)
video_path = self.root / self._meta.get_video_file_path(ep_idx, vid_key)
if not video_path.exists():
return False
@@ -240,7 +240,7 @@ class DatasetReader:
from_timestamp = ep[f"videos/{vid_key}/from_timestamp"]
shifted_query_ts = [from_timestamp + ts for ts in query_ts]
video_path = self._root / self._meta.get_video_file_path(ep_idx, vid_key)
video_path = self.root / self._meta.get_video_file_path(ep_idx, vid_key)
frames = decode_video_frames(video_path, shifted_query_ts, self._tolerance_s, self._video_backend)
item[vid_key] = frames.squeeze(0)
+58 -23
View File
@@ -37,7 +37,7 @@ from lerobot.datasets.video_utils import (
get_safe_default_codec,
resolve_vcodec,
)
from lerobot.utils.constants import HF_LEROBOT_HOME
from lerobot.utils.constants import HF_LEROBOT_HUB_CACHE
logger = logging.getLogger(__name__)
@@ -144,10 +144,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
Args:
repo_id (str): This is the repo id that will be used to fetch the dataset.
root (Path | None, optional): Local directory where the dataset will be downloaded and
stored. If set, all dataset files will be stored directly under this path. If not set, the
dataset files will be stored under $HF_LEROBOT_HOME/repo_id (configurable via the
HF_LEROBOT_HOME environment variable).
root (Path | None, optional): Local directory where the dataset will be read from or downloaded
into. If set, all dataset files are materialized directly under this path. If not set,
existing local datasets are still looked up under ``$HF_LEROBOT_HOME/{repo_id}``, but Hub
downloads use a revision-safe snapshot cache under
``$HF_LEROBOT_HOME/hub``.
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
their episode_index in this list. Defaults to None.
image_transforms (Callable | None, optional): You can pass standard v2 image transforms from
@@ -190,7 +191,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
"""
super().__init__()
self.repo_id = repo_id
self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
self._requested_root = Path(root) if root else None
self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps
self.episodes = episodes
@@ -201,12 +202,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
self._vcodec = resolve_vcodec(vcodec)
self._encoder_threads = encoder_threads
self.root.mkdir(exist_ok=True, parents=True)
if self._requested_root is not None:
self._requested_root.mkdir(exist_ok=True, parents=True)
# Load metadata
# Load metadata (sets self.root once from the resolved metadata root)
self.meta = LeRobotDatasetMetadata(
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
self.repo_id, self._requested_root, self.revision, force_cache_sync=force_cache_sync
)
self.root = self.meta.root
self.revision = self.meta.revision
# Create reader (hf_dataset loaded below)
self.reader = DatasetReader(
@@ -556,14 +560,33 @@ class LeRobotDataset(torch.utils.data.Dataset):
if self.episodes is not None:
# Reader is guaranteed to exist here (created in __init__ before _download)
files = self.reader.get_episodes_file_paths()
snapshot_download(
self.repo_id,
repo_type="dataset",
revision=self.revision,
local_dir=self.root,
allow_patterns=files,
ignore_patterns=ignore_patterns,
)
if self._requested_root is None:
self.meta.root = Path(
snapshot_download(
self.repo_id,
repo_type="dataset",
revision=self.revision,
cache_dir=HF_LEROBOT_HUB_CACHE,
allow_patterns=files,
ignore_patterns=ignore_patterns,
)
)
else:
self._requested_root.mkdir(exist_ok=True, parents=True)
snapshot_download(
self.repo_id,
repo_type="dataset",
revision=self.revision,
local_dir=self._requested_root,
allow_patterns=files,
ignore_patterns=ignore_patterns,
)
self.meta.root = self._requested_root
# Propagate resolved root from metadata (single source of truth)
self.root = self.meta.root
self.reader.root = self.meta.root
# ── Class constructors ────────────────────────────────────────────
@@ -635,6 +658,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
metadata_buffer_size=metadata_buffer_size,
)
obj.repo_id = obj.meta.repo_id
obj._requested_root = obj.meta.root
obj.root = obj.meta.root
obj.revision = None
obj.tolerance_s = tolerance_s
@@ -695,8 +719,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
Args:
repo_id: Repository identifier of the existing dataset.
root: Local directory of the dataset. Defaults to
``$HF_LEROBOT_HOME/{repo_id}``.
root: Local directory of the dataset. When provided, Hub downloads
are materialized directly into this directory. When omitted,
Hub downloads use a revision-safe snapshot cache under
``$HF_LEROBOT_HOME/hub``.
tolerance_s: Timestamp synchronization tolerance in seconds.
revision: Git revision (branch, tag, or commit hash). Defaults to
current codebase version tag.
@@ -716,11 +742,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
Returns:
A :class:`LeRobotDataset` in write mode, ready to append episodes.
"""
if not root:
raise ValueError(
"resume() requires an explicit 'root' directory because it creates a DatasetWriter. "
"Writing into the revision-safe Hub snapshot cache (used when root=None) would corrupt "
"the shared cache. Please provide a local directory path."
)
vcodec = resolve_vcodec(vcodec)
obj = cls.__new__(cls)
obj.repo_id = repo_id
obj.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
obj.root.mkdir(exist_ok=True, parents=True)
obj._requested_root = Path(root)
obj.revision = revision if revision else CODEBASE_VERSION
obj.tolerance_s = tolerance_s
obj.image_transforms = None
@@ -731,10 +762,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj._vcodec = vcodec
obj._encoder_threads = encoder_threads
# Load metadata
if obj._requested_root is not None:
obj._requested_root.mkdir(exist_ok=True, parents=True)
# Load metadata (revision-safe when root is not provided)
obj.meta = LeRobotDatasetMetadata(
obj.repo_id, obj.root, obj.revision, force_cache_sync=force_cache_sync
obj.repo_id, obj._requested_root, obj.revision, force_cache_sync=force_cache_sync
)
obj.root = obj.meta.root
# Reader is lazily created on first access (write-only mode)
obj.reader = None
+10 -4
View File
@@ -255,7 +255,9 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
Args:
repo_id (str): This is the repo id that will be used to fetch the dataset.
root (Path | None, optional): Local directory to use for downloading/writing files.
root (Path | None, optional): Local directory to use for local datasets. When omitted, Hub
metadata is resolved through a revision-safe snapshot cache under
``$HF_LEROBOT_HOME/hub``.
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
their episode_index in this list.
image_transforms (Callable | None, optional): Transform to apply to image data.
@@ -271,7 +273,8 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
"""
super().__init__()
self.repo_id = repo_id
self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
self._requested_root = Path(root) if root else None
self.root = self._requested_root if self._requested_root is not None else HF_LEROBOT_HOME / repo_id
self.streaming_from_local = root is not None
self.image_transforms = image_transforms
@@ -288,12 +291,15 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
# We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown)
self.video_decoder_cache = None
self.root.mkdir(exist_ok=True, parents=True)
if self._requested_root is not None:
self.root.mkdir(exist_ok=True, parents=True)
# Load metadata
self.meta = LeRobotDatasetMetadata(
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
self.repo_id, self._requested_root, self.revision, force_cache_sync=force_cache_sync
)
self.root = self.meta.root
self.revision = self.meta.revision
# Check version
check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)
+13
View File
@@ -18,6 +18,7 @@ import importlib.resources
import json
import logging
from collections.abc import Iterator
from pathlib import Path
from typing import Any
import datasets
@@ -101,6 +102,18 @@ DEFAULT_FEATURES = {
}
def has_legacy_hub_download_metadata(root: Path) -> bool:
    """Detect whether *root* was populated by the legacy ``local_dir`` download mode.

    ``snapshot_download(local_dir=...)`` leaves lightweight bookkeeping files
    under ``<local_dir>/.cache/huggingface/download/``. Finding that directory
    is a reliable sign the dataset was fetched with the old, non-revision-safe
    ``local_dir`` mode and should be re-downloaded through the snapshot cache.
    """
    # Probe only for the marker directory; its mere existence is the signal.
    legacy_marker = root / ".cache" / "huggingface" / "download"
    return legacy_marker.exists()
def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) -> tuple[int, int]:
if file_idx == chunks_size - 1:
file_idx = 0
+4
View File
@@ -65,6 +65,10 @@ if "LEROBOT_HOME" in os.environ:
# cache dir
default_cache_path = Path(HF_HOME) / "lerobot"
HF_LEROBOT_HOME = Path(os.getenv("HF_LEROBOT_HOME", default_cache_path)).expanduser()
# LeRobot's own revision-safe Hub cache (NOT the system-wide ~/.cache/huggingface/hub/).
# Used as the ``cache_dir`` argument to ``snapshot_download`` so that different
# dataset revisions are stored in isolated snapshot directories.
HF_LEROBOT_HUB_CACHE = HF_LEROBOT_HOME / "hub"
# calibration dir
default_calibration_path = HF_LEROBOT_HOME / "calibration"