Added mock context manager to tests in order to avoid calls to the hub for dummy datasets

This commit is contained in:
Michel Aractingi
2025-07-30 12:11:16 +02:00
parent 527ae8e557
commit 1c79e3dec1
2 changed files with 36 additions and 4 deletions
+17 -2
View File
@@ -15,6 +15,7 @@
# limitations under the License.
import torch
from unittest.mock import patch
from lerobot.datasets.aggregate import aggregate_datasets
from lerobot.datasets.lerobot_dataset import LeRobotDataset
@@ -207,7 +208,14 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
aggr_root=tmp_path / "test_aggr",
)
aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_aggr", root=tmp_path / "test_aggr")
# Mock the revision to prevent Hub calls during dataset loading
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(tmp_path / "test_aggr")
aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_aggr", root=tmp_path / "test_aggr")
# Run all assertion functions
expected_total_episodes = ds_0.num_episodes + ds_1.num_episodes
@@ -250,7 +258,14 @@ def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory):
video_files_size_in_mb=0.1,
)
aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_small_aggr", root=tmp_path / "small_aggr")
# Mock the revision to prevent Hub calls during dataset loading
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(tmp_path / "small_aggr")
aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_small_aggr", root=tmp_path / "small_aggr")
# Verify aggregation worked correctly despite file size constraints
expected_total_episodes = ds_0_num_episodes + ds_1_num_episodes