fix(dataset): use revision-safe Hub cache for downloaded datasets (#3233)

* refactor(dataset): enhance dataset root directory handling and introduce hub cache support

- Updated DatasetConfig and LeRobotDatasetMetadata to clarify root directory behavior and introduce a dedicated hub cache for downloads.
- Refactored LeRobotDataset and StreamingLeRobotDataset to utilize the new hub cache and improve directory management.
- Added tests to ensure correct behavior when using the hub cache and handling different revisions without a specified root directory.

* refactor(dataset): improve root directory handling in LeRobotDataset

- Updated LeRobotDataset to store the requested root path separately from the actual root path.
- Adjusted metadata loading to use the requested root, enhancing clarity and consistency in directory management.

* refactor(dataset): minor improvements for hub cache support

* chore(datasets): guard in resume + assertion test

---------

Co-authored-by: AdilZouitine <adilzouitinegm@gmail.com>
Co-authored-by: mickaelChen <mickael.chen.levinson@gmail.com>
This commit is contained in:
Steven Palma
2026-03-27 22:21:55 +01:00
committed by GitHub
parent 975d89b38d
commit 4e45acca52
8 changed files with 440 additions and 40 deletions
+318
View File
@@ -19,9 +19,15 @@ Tests focus on mode contracts (read-only, write-only, resume), guards,
property delegation, and the full create-record-finalize-read lifecycle.
"""
from pathlib import Path
from unittest.mock import Mock
import pytest
import torch
import lerobot.datasets.dataset_metadata as dataset_metadata_module
import lerobot.datasets.lerobot_dataset as lerobot_dataset_module
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.dataset_reader import DatasetReader
from lerobot.datasets.dataset_writer import DatasetWriter
from lerobot.datasets.lerobot_dataset import LeRobotDataset
@@ -30,12 +36,69 @@ from tests.fixtures.constants import DEFAULT_FPS, DUMMY_REPO_ID
# Minimal motor-feature schema used by the hub-cache tests: a single
# 2-element float32 "state" feature.
SIMPLE_FEATURES = {
    "state": {"dtype": "float32", "shape": (2,), "names": None},
}

# Schema written to the "main" snapshot: SIMPLE_FEATURES plus an extra
# "test" feature, so tests can distinguish the main revision from an
# older snapshot that lacks it.
SNAPSHOT_MAIN_FEATURES = {
    **SIMPLE_FEATURES,
    "test": {"dtype": "float32", "shape": (2,), "names": None},
}
def _make_frame(task: str = "Dummy task") -> dict:
return {"task": task, "state": torch.randn(2)}
def _set_default_cache_root(monkeypatch: pytest.MonkeyPatch, cache_root: Path) -> None:
    """Redirect the dataset modules' cache locations to *cache_root*.

    Patches ``HF_LEROBOT_HOME`` on the metadata module and
    ``HF_LEROBOT_HUB_CACHE`` on both the metadata and dataset modules, so
    that downloads resolve under the temporary cache root.
    """
    hub_cache = cache_root / "hub"
    monkeypatch.setattr(dataset_metadata_module, "HF_LEROBOT_HOME", cache_root)
    for module in (dataset_metadata_module, lerobot_dataset_module):
        monkeypatch.setattr(module, "HF_LEROBOT_HUB_CACHE", hub_cache)
def _write_dataset_tree(
root: Path,
*,
motor_features: dict[str, dict],
info_factory,
stats_factory,
tasks_factory,
episodes_factory,
hf_dataset_factory,
create_info,
create_stats,
create_tasks,
create_episodes,
create_hf_dataset,
) -> None:
root.mkdir(parents=True, exist_ok=True)
info = info_factory(
total_episodes=1,
total_frames=3,
total_tasks=1,
use_videos=False,
motor_features=motor_features,
camera_features={},
)
tasks = tasks_factory(total_tasks=1)
episodes = episodes_factory(
features=info["features"],
fps=info["fps"],
total_episodes=1,
total_frames=3,
tasks=tasks,
)
stats = stats_factory(features=info["features"])
hf_dataset = hf_dataset_factory(
features=info["features"],
tasks=tasks,
episodes=episodes,
fps=info["fps"],
)
create_info(root, info)
create_stats(root, stats)
create_tasks(root, tasks)
create_episodes(root, episodes)
create_hf_dataset(root, hf_dataset)
# ── Read-only mode (via __init__) ────────────────────────────────────
@@ -75,6 +138,261 @@ def test_len_matches_num_frames(tmp_path, lerobot_dataset_factory):
assert len(dataset) == dataset.num_frames
def test_metadata_without_root_uses_hub_cache_snapshot_download(
    tmp_path,
    info_factory,
    stats_factory,
    tasks_factory,
    episodes_factory,
    hf_dataset_factory,
    create_info,
    create_stats,
    create_tasks,
    create_episodes,
    create_hf_dataset,
    monkeypatch,
):
    """Metadata refresh uses the dedicated Hub cache instead of a shared local_dir mirror."""
    cache_root = tmp_path / "lerobot_cache"
    hub_cache = cache_root / "hub"
    snapshot_root = hub_cache / "datasets--dummy--repo" / "snapshots" / "commit-main"

    # Lay down a snapshot in the hub cache for snapshot_download to "return".
    _write_dataset_tree(
        snapshot_root,
        motor_features=SNAPSHOT_MAIN_FEATURES,
        info_factory=info_factory,
        stats_factory=stats_factory,
        tasks_factory=tasks_factory,
        episodes_factory=episodes_factory,
        hf_dataset_factory=hf_dataset_factory,
        create_info=create_info,
        create_stats=create_stats,
        create_tasks=create_tasks,
        create_episodes=create_episodes,
        create_hf_dataset=create_hf_dataset,
    )
    _set_default_cache_root(monkeypatch, cache_root)

    fake_download = Mock(return_value=str(snapshot_root))
    monkeypatch.setattr(dataset_metadata_module, "snapshot_download", fake_download)

    meta = LeRobotDatasetMetadata(repo_id=DUMMY_REPO_ID, revision="main", force_cache_sync=True)

    assert meta.root == snapshot_root
    # Exactly one download, routed through cache_dir (never local_dir).
    fake_download.assert_called_once_with(
        DUMMY_REPO_ID,
        repo_type="dataset",
        revision="main",
        cache_dir=hub_cache,
        allow_patterns="meta/",
        ignore_patterns=None,
    )
def test_without_root_reads_different_revisions_from_distinct_snapshot_roots(
    tmp_path,
    info_factory,
    stats_factory,
    tasks_factory,
    episodes_factory,
    hf_dataset_factory,
    create_info,
    create_stats,
    create_tasks,
    create_episodes,
    create_hf_dataset,
    monkeypatch,
):
    """Different revisions resolve to different on-disk snapshot roots."""
    old_revision = "b59010db93eb6cc3cf06ef2f7cae1bbe62b726d9"
    cache_root = tmp_path / "lerobot_cache"
    snapshots_dir = cache_root / "hub" / "datasets--dummy--repo" / "snapshots"
    snapshot_roots = {
        "main": snapshots_dir / "commit-main",
        old_revision: snapshots_dir / "commit-old",
    }

    factories = dict(
        info_factory=info_factory,
        stats_factory=stats_factory,
        tasks_factory=tasks_factory,
        episodes_factory=episodes_factory,
        hf_dataset_factory=hf_dataset_factory,
        create_info=create_info,
        create_stats=create_stats,
        create_tasks=create_tasks,
        create_episodes=create_episodes,
        create_hf_dataset=create_hf_dataset,
    )
    # "main" carries the extra "test" feature; the old revision does not.
    _write_dataset_tree(snapshot_roots["main"], motor_features=SNAPSHOT_MAIN_FEATURES, **factories)
    _write_dataset_tree(snapshot_roots[old_revision], motor_features=SIMPLE_FEATURES, **factories)
    _set_default_cache_root(monkeypatch, cache_root)

    def _resolve(repo_id, **kwargs):
        # Map the requested revision to its pre-built snapshot directory.
        return str(snapshot_roots[kwargs["revision"]])

    meta_download = Mock(side_effect=_resolve)
    data_download = Mock(side_effect=_resolve)
    monkeypatch.setattr(dataset_metadata_module, "snapshot_download", meta_download)
    monkeypatch.setattr(lerobot_dataset_module, "snapshot_download", data_download)

    main_dataset = LeRobotDataset(
        repo_id=DUMMY_REPO_ID, revision="main", download_videos=False, force_cache_sync=True
    )
    old_dataset = LeRobotDataset(
        repo_id=DUMMY_REPO_ID, revision=old_revision, download_videos=False, force_cache_sync=True
    )

    assert main_dataset.root == snapshot_roots["main"]
    assert old_dataset.root == snapshot_roots[old_revision]
    assert "test" in main_dataset.hf_dataset.column_names
    assert "test" not in old_dataset.hf_dataset.column_names

    # Both the metadata and the data downloads must use cache_dir, never local_dir.
    for mock_download in (meta_download, data_download):
        assert mock_download.call_count == 2
        for download_call in mock_download.call_args_list:
            assert download_call.kwargs["cache_dir"] == cache_root / "hub"
            assert "local_dir" not in download_call.kwargs
def test_metadata_without_root_ignores_legacy_local_dir_cache(
    tmp_path,
    info_factory,
    stats_factory,
    tasks_factory,
    episodes_factory,
    hf_dataset_factory,
    create_info,
    create_stats,
    create_tasks,
    create_episodes,
    create_hf_dataset,
    monkeypatch,
):
    """Legacy local-dir mirrors are bypassed in favor of revision-safe snapshots."""
    cache_root = tmp_path / "lerobot_cache"
    legacy_root = cache_root / DUMMY_REPO_ID
    snapshot_root = cache_root / "hub" / "datasets--dummy--repo" / "snapshots" / "commit-main"

    factories = dict(
        info_factory=info_factory,
        stats_factory=stats_factory,
        tasks_factory=tasks_factory,
        episodes_factory=episodes_factory,
        hf_dataset_factory=hf_dataset_factory,
        create_info=create_info,
        create_stats=create_stats,
        create_tasks=create_tasks,
        create_episodes=create_episodes,
        create_hf_dataset=create_hf_dataset,
    )
    # Legacy mirror holds the old schema (no "test" feature) plus the
    # tell-tale .cache/huggingface/download marker of a local_dir download.
    _write_dataset_tree(legacy_root, motor_features=SIMPLE_FEATURES, **factories)
    (legacy_root / ".cache" / "huggingface" / "download").mkdir(parents=True, exist_ok=True)
    # The hub-cache snapshot holds the current schema.
    _write_dataset_tree(snapshot_root, motor_features=SNAPSHOT_MAIN_FEATURES, **factories)
    _set_default_cache_root(monkeypatch, cache_root)

    fake_download = Mock(return_value=str(snapshot_root))
    monkeypatch.setattr(dataset_metadata_module, "snapshot_download", fake_download)

    meta = LeRobotDatasetMetadata(repo_id=DUMMY_REPO_ID, revision="main")

    # Metadata must come from the snapshot, not the legacy mirror.
    assert meta.root == snapshot_root
    assert "test" in meta.features
    assert fake_download.call_count == 1
def test_download_without_root_uses_hub_cache(
    tmp_path,
    info_factory,
    stats_factory,
    tasks_factory,
    episodes_factory,
    hf_dataset_factory,
    create_info,
    create_stats,
    create_tasks,
    create_episodes,
    create_hf_dataset,
    monkeypatch,
):
    """LeRobotDataset._download() uses cache_dir (not local_dir) when root is not provided."""
    cache_root = tmp_path / "lerobot_cache"
    snapshot_root = cache_root / "hub" / "datasets--dummy--repo" / "snapshots" / "commit-main"

    factories = dict(
        info_factory=info_factory,
        stats_factory=stats_factory,
        tasks_factory=tasks_factory,
        episodes_factory=episodes_factory,
        hf_dataset_factory=hf_dataset_factory,
        create_info=create_info,
        create_stats=create_stats,
        create_tasks=create_tasks,
        create_episodes=create_episodes,
        create_hf_dataset=create_hf_dataset,
    )
    # Pre-populate the snapshot directory so metadata loads succeed, while
    # the data download path is still exercised.
    _write_dataset_tree(snapshot_root, motor_features=SIMPLE_FEATURES, **factories)
    _set_default_cache_root(monkeypatch, cache_root)

    meta_download = Mock(return_value=str(snapshot_root))
    monkeypatch.setattr(dataset_metadata_module, "snapshot_download", meta_download)
    # The data download returns the same root; the data written by
    # _write_dataset_tree already lives there.
    data_download = Mock(return_value=str(snapshot_root))
    monkeypatch.setattr(lerobot_dataset_module, "snapshot_download", data_download)

    LeRobotDataset(repo_id=DUMMY_REPO_ID, revision="main", force_cache_sync=True)

    # _download() must go through snapshot_download with cache_dir, not local_dir.
    assert data_download.call_count == 1
    download_kwargs = data_download.call_args.kwargs
    assert download_kwargs["cache_dir"] == cache_root / "hub"
    assert "local_dir" not in download_kwargs
# ── Write-only mode (via create()) ──────────────────────────────────