mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-18 10:10:08 +00:00
Added mock context manager to tests in order to avoid calls to the hub for dummy datasets
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from unittest.mock import patch
|
||||
|
||||
from lerobot.datasets.aggregate import aggregate_datasets
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
@@ -207,7 +208,14 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
|
||||
aggr_root=tmp_path / "test_aggr",
|
||||
)
|
||||
|
||||
aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_aggr", root=tmp_path / "test_aggr")
|
||||
# Mock the revision to prevent Hub calls during dataset loading
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(tmp_path / "test_aggr")
|
||||
aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_aggr", root=tmp_path / "test_aggr")
|
||||
|
||||
# Run all assertion functions
|
||||
expected_total_episodes = ds_0.num_episodes + ds_1.num_episodes
|
||||
@@ -250,7 +258,14 @@ def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory):
|
||||
video_files_size_in_mb=0.1,
|
||||
)
|
||||
|
||||
aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_small_aggr", root=tmp_path / "small_aggr")
|
||||
# Mock the revision to prevent Hub calls during dataset loading
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(tmp_path / "small_aggr")
|
||||
aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_small_aggr", root=tmp_path / "small_aggr")
|
||||
|
||||
# Verify aggregation worked correctly despite file size constraints
|
||||
expected_total_episodes = ds_0_num_episodes + ds_1_num_episodes
|
||||
|
||||
Reference in New Issue
Block a user