mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 08:39:49 +00:00
fix(dataset): use revision-safe Hub cache for downloaded datasets (#3233)
* refactor(dataset): enhance dataset root directory handling and introduce hub cache support
  - Updated DatasetConfig and LeRobotDatasetMetadata to clarify root directory behavior and introduce a dedicated hub cache for downloads.
  - Refactored LeRobotDataset and StreamingLeRobotDataset to utilize the new hub cache and improve directory management.
  - Added tests to ensure correct behavior when using the hub cache and handling different revisions without a specified root directory.
* refactor(dataset): improve root directory handling in LeRobotDataset
  - Updated LeRobotDataset to store the requested root path separately from the actual root path.
  - Adjusted metadata loading to use the requested root, enhancing clarity and consistency in directory management.
* refactor(dataset): minor improvements for hub cache support
* chore(datasets): guard in resume + assertion test

Co-authored-by: AdilZouitine <adilzouitinegm@gmail.com>
Co-authored-by: mickaelChen <mickael.chen.levinson@gmail.com>
This commit is contained in:
@@ -19,9 +19,15 @@ Tests focus on mode contracts (read-only, write-only, resume), guards,
|
||||
property delegation, and the full create-record-finalize-read lifecycle.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import lerobot.datasets.dataset_metadata as dataset_metadata_module
|
||||
import lerobot.datasets.lerobot_dataset as lerobot_dataset_module
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.dataset_reader import DatasetReader
|
||||
from lerobot.datasets.dataset_writer import DatasetWriter
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
@@ -30,12 +36,69 @@ from tests.fixtures.constants import DEFAULT_FPS, DUMMY_REPO_ID
|
||||
# Minimal feature spec used by the snapshot fixtures: a single float32 state vector.
SIMPLE_FEATURES = {
    "state": {"dtype": "float32", "shape": (2,), "names": None},
}
# Superset of SIMPLE_FEATURES with an extra "test" feature; tests use its
# presence/absence to tell the "main" snapshot apart from an older revision.
SNAPSHOT_MAIN_FEATURES = {
    **SIMPLE_FEATURES,
    "test": {"dtype": "float32", "shape": (2,), "names": None},
}
|
||||
|
||||
|
||||
def _make_frame(task: str = "Dummy task") -> dict:
|
||||
return {"task": task, "state": torch.randn(2)}
|
||||
|
||||
|
||||
def _set_default_cache_root(monkeypatch: pytest.MonkeyPatch, cache_root: Path) -> None:
    """Redirect both dataset modules' default cache constants to *cache_root*.

    The hub cache is patched in both modules so metadata and data downloads
    resolve against the same temporary ``<cache_root>/hub`` directory.
    """
    hub_cache = cache_root / "hub"
    monkeypatch.setattr(dataset_metadata_module, "HF_LEROBOT_HOME", cache_root)
    for module in (dataset_metadata_module, lerobot_dataset_module):
        monkeypatch.setattr(module, "HF_LEROBOT_HUB_CACHE", hub_cache)
|
||||
|
||||
|
||||
def _write_dataset_tree(
|
||||
root: Path,
|
||||
*,
|
||||
motor_features: dict[str, dict],
|
||||
info_factory,
|
||||
stats_factory,
|
||||
tasks_factory,
|
||||
episodes_factory,
|
||||
hf_dataset_factory,
|
||||
create_info,
|
||||
create_stats,
|
||||
create_tasks,
|
||||
create_episodes,
|
||||
create_hf_dataset,
|
||||
) -> None:
|
||||
root.mkdir(parents=True, exist_ok=True)
|
||||
info = info_factory(
|
||||
total_episodes=1,
|
||||
total_frames=3,
|
||||
total_tasks=1,
|
||||
use_videos=False,
|
||||
motor_features=motor_features,
|
||||
camera_features={},
|
||||
)
|
||||
tasks = tasks_factory(total_tasks=1)
|
||||
episodes = episodes_factory(
|
||||
features=info["features"],
|
||||
fps=info["fps"],
|
||||
total_episodes=1,
|
||||
total_frames=3,
|
||||
tasks=tasks,
|
||||
)
|
||||
stats = stats_factory(features=info["features"])
|
||||
hf_dataset = hf_dataset_factory(
|
||||
features=info["features"],
|
||||
tasks=tasks,
|
||||
episodes=episodes,
|
||||
fps=info["fps"],
|
||||
)
|
||||
|
||||
create_info(root, info)
|
||||
create_stats(root, stats)
|
||||
create_tasks(root, tasks)
|
||||
create_episodes(root, episodes)
|
||||
create_hf_dataset(root, hf_dataset)
|
||||
|
||||
|
||||
# ── Read-only mode (via __init__) ────────────────────────────────────
|
||||
|
||||
|
||||
@@ -75,6 +138,261 @@ def test_len_matches_num_frames(tmp_path, lerobot_dataset_factory):
|
||||
assert len(dataset) == dataset.num_frames
|
||||
|
||||
|
||||
def test_metadata_without_root_uses_hub_cache_snapshot_download(
    tmp_path,
    info_factory,
    stats_factory,
    tasks_factory,
    episodes_factory,
    hf_dataset_factory,
    create_info,
    create_stats,
    create_tasks,
    create_episodes,
    create_hf_dataset,
    monkeypatch,
):
    """Without a root, metadata refresh goes through the dedicated Hub cache."""
    cache_root = tmp_path / "lerobot_cache"
    snapshot_root = cache_root / "hub" / "datasets--dummy--repo" / "snapshots" / "commit-main"

    factories = dict(
        info_factory=info_factory,
        stats_factory=stats_factory,
        tasks_factory=tasks_factory,
        episodes_factory=episodes_factory,
        hf_dataset_factory=hf_dataset_factory,
        create_info=create_info,
        create_stats=create_stats,
        create_tasks=create_tasks,
        create_episodes=create_episodes,
        create_hf_dataset=create_hf_dataset,
    )
    _write_dataset_tree(snapshot_root, motor_features=SNAPSHOT_MAIN_FEATURES, **factories)

    _set_default_cache_root(monkeypatch, cache_root)
    download_mock = Mock(return_value=str(snapshot_root))
    monkeypatch.setattr(dataset_metadata_module, "snapshot_download", download_mock)

    meta = LeRobotDatasetMetadata(repo_id=DUMMY_REPO_ID, revision="main", force_cache_sync=True)

    # The metadata root is the resolved snapshot, fetched exactly once through
    # cache_dir — never through a shared local_dir mirror.
    assert meta.root == snapshot_root
    assert download_mock.call_count == 1
    assert download_mock.call_args.args == (DUMMY_REPO_ID,)
    expected_kwargs = {
        "repo_type": "dataset",
        "revision": "main",
        "cache_dir": cache_root / "hub",
        "allow_patterns": "meta/",
        "ignore_patterns": None,
    }
    assert download_mock.call_args.kwargs == expected_kwargs
|
||||
|
||||
|
||||
def test_without_root_reads_different_revisions_from_distinct_snapshot_roots(
    tmp_path,
    info_factory,
    stats_factory,
    tasks_factory,
    episodes_factory,
    hf_dataset_factory,
    create_info,
    create_stats,
    create_tasks,
    create_episodes,
    create_hf_dataset,
    monkeypatch,
):
    """Each requested revision resolves to its own on-disk snapshot root."""
    old_revision = "b59010db93eb6cc3cf06ef2f7cae1bbe62b726d9"
    cache_root = tmp_path / "lerobot_cache"
    snapshots_dir = cache_root / "hub" / "datasets--dummy--repo" / "snapshots"
    main_root = snapshots_dir / "commit-main"
    old_root = snapshots_dir / "commit-old"

    factories = dict(
        info_factory=info_factory,
        stats_factory=stats_factory,
        tasks_factory=tasks_factory,
        episodes_factory=episodes_factory,
        hf_dataset_factory=hf_dataset_factory,
        create_info=create_info,
        create_stats=create_stats,
        create_tasks=create_tasks,
        create_episodes=create_episodes,
        create_hf_dataset=create_hf_dataset,
    )
    # The "main" snapshot carries the extra "test" feature; the old one does not.
    for tree_root, motors in ((main_root, SNAPSHOT_MAIN_FEATURES), (old_root, SIMPLE_FEATURES)):
        _write_dataset_tree(tree_root, motor_features=motors, **factories)

    _set_default_cache_root(monkeypatch, cache_root)
    revision_to_root = {"main": main_root, old_revision: old_root}

    def _resolve_snapshot(repo_id, **kwargs):
        return str(revision_to_root[kwargs["revision"]])

    meta_download = Mock(side_effect=_resolve_snapshot)
    data_download = Mock(side_effect=_resolve_snapshot)
    monkeypatch.setattr(dataset_metadata_module, "snapshot_download", meta_download)
    monkeypatch.setattr(lerobot_dataset_module, "snapshot_download", data_download)

    main_dataset = LeRobotDataset(
        repo_id=DUMMY_REPO_ID, revision="main", download_videos=False, force_cache_sync=True
    )
    old_dataset = LeRobotDataset(
        repo_id=DUMMY_REPO_ID, revision=old_revision, download_videos=False, force_cache_sync=True
    )

    assert main_dataset.root == main_root
    assert old_dataset.root == old_root
    assert "test" in main_dataset.hf_dataset.column_names
    assert "test" not in old_dataset.hf_dataset.column_names

    # Both metadata and data downloads use cache_dir, never local_dir.
    for download_mock in (meta_download, data_download):
        assert download_mock.call_count == 2
        for download_call in download_mock.call_args_list:
            assert download_call.kwargs["cache_dir"] == cache_root / "hub"
            assert "local_dir" not in download_call.kwargs
|
||||
|
||||
|
||||
def test_metadata_without_root_ignores_legacy_local_dir_cache(
    tmp_path,
    info_factory,
    stats_factory,
    tasks_factory,
    episodes_factory,
    hf_dataset_factory,
    create_info,
    create_stats,
    create_tasks,
    create_episodes,
    create_hf_dataset,
    monkeypatch,
):
    """A pre-existing legacy local-dir mirror is bypassed for the hub snapshot."""
    cache_root = tmp_path / "lerobot_cache"
    legacy_root = cache_root / DUMMY_REPO_ID
    snapshot_root = cache_root / "hub" / "datasets--dummy--repo" / "snapshots" / "commit-main"

    factories = dict(
        info_factory=info_factory,
        stats_factory=stats_factory,
        tasks_factory=tasks_factory,
        episodes_factory=episodes_factory,
        hf_dataset_factory=hf_dataset_factory,
        create_info=create_info,
        create_stats=create_stats,
        create_tasks=create_tasks,
        create_episodes=create_episodes,
        create_hf_dataset=create_hf_dataset,
    )
    _write_dataset_tree(legacy_root, motor_features=SIMPLE_FEATURES, **factories)
    # Recreate the marker directory left behind by the old local_dir layout.
    (legacy_root / ".cache" / "huggingface" / "download").mkdir(parents=True, exist_ok=True)
    _write_dataset_tree(snapshot_root, motor_features=SNAPSHOT_MAIN_FEATURES, **factories)

    _set_default_cache_root(monkeypatch, cache_root)
    download_mock = Mock(return_value=str(snapshot_root))
    monkeypatch.setattr(dataset_metadata_module, "snapshot_download", download_mock)

    meta = LeRobotDatasetMetadata(repo_id=DUMMY_REPO_ID, revision="main")

    # The snapshot wins: its extra "test" feature proves the legacy mirror
    # (which only has SIMPLE_FEATURES) was not read.
    assert meta.root == snapshot_root
    assert "test" in meta.features
    assert download_mock.call_count == 1
|
||||
|
||||
|
||||
def test_download_without_root_uses_hub_cache(
    tmp_path,
    info_factory,
    stats_factory,
    tasks_factory,
    episodes_factory,
    hf_dataset_factory,
    create_info,
    create_stats,
    create_tasks,
    create_episodes,
    create_hf_dataset,
    monkeypatch,
):
    """LeRobotDataset._download() passes cache_dir (not local_dir) without a root."""
    cache_root = tmp_path / "lerobot_cache"
    snapshot_root = cache_root / "hub" / "datasets--dummy--repo" / "snapshots" / "commit-main"

    # Pre-populate the snapshot so metadata loads; the files written here also
    # serve as the "downloaded" data when _download() is triggered.
    factories = dict(
        info_factory=info_factory,
        stats_factory=stats_factory,
        tasks_factory=tasks_factory,
        episodes_factory=episodes_factory,
        hf_dataset_factory=hf_dataset_factory,
        create_info=create_info,
        create_stats=create_stats,
        create_tasks=create_tasks,
        create_episodes=create_episodes,
        create_hf_dataset=create_hf_dataset,
    )
    _write_dataset_tree(snapshot_root, motor_features=SIMPLE_FEATURES, **factories)

    _set_default_cache_root(monkeypatch, cache_root)
    monkeypatch.setattr(
        dataset_metadata_module, "snapshot_download", Mock(return_value=str(snapshot_root))
    )
    data_download = Mock(return_value=str(snapshot_root))
    monkeypatch.setattr(lerobot_dataset_module, "snapshot_download", data_download)

    LeRobotDataset(repo_id=DUMMY_REPO_ID, revision="main", force_cache_sync=True)

    # Exactly one data download, routed through the hub cache.
    assert data_download.call_count == 1
    data_kwargs = data_download.call_args.kwargs
    assert data_kwargs["cache_dir"] == cache_root / "hub"
    assert "local_dir" not in data_kwargs
|
||||
|
||||
|
||||
# ── Write-only mode (via create()) ──────────────────────────────────
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user