Added mock context manager to tests in order to avoid calls to the hub for dummy datasets

This commit is contained in:
Michel Aractingi
2025-07-30 12:11:16 +02:00
parent 527ae8e557
commit 1c79e3dec1
2 changed files with 36 additions and 4 deletions
+17 -2
View File
@@ -15,6 +15,7 @@
# limitations under the License. # limitations under the License.
import torch import torch
from unittest.mock import patch
from lerobot.datasets.aggregate import aggregate_datasets from lerobot.datasets.aggregate import aggregate_datasets
from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.lerobot_dataset import LeRobotDataset
@@ -207,7 +208,14 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
aggr_root=tmp_path / "test_aggr", aggr_root=tmp_path / "test_aggr",
) )
aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_aggr", root=tmp_path / "test_aggr") # Mock the revision to prevent Hub calls during dataset loading
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(tmp_path / "test_aggr")
aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_aggr", root=tmp_path / "test_aggr")
# Run all assertion functions # Run all assertion functions
expected_total_episodes = ds_0.num_episodes + ds_1.num_episodes expected_total_episodes = ds_0.num_episodes + ds_1.num_episodes
@@ -250,7 +258,14 @@ def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory):
video_files_size_in_mb=0.1, video_files_size_in_mb=0.1,
) )
aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_small_aggr", root=tmp_path / "small_aggr") # Mock the revision to prevent Hub calls during dataset loading
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(tmp_path / "small_aggr")
aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_small_aggr", root=tmp_path / "small_aggr")
# Verify aggregation worked correctly despite file size constraints # Verify aggregation worked correctly despite file size constraints
expected_total_episodes = ds_0_num_episodes + ds_1_num_episodes expected_total_episodes = ds_0_num_episodes + ds_1_num_episodes
+19 -2
View File
@@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from unittest.mock import patch
from lerobot.calibrate import CalibrateConfig, calibrate from lerobot.calibrate import CalibrateConfig, calibrate
from lerobot.record import DatasetRecordConfig, RecordConfig, record from lerobot.record import DatasetRecordConfig, RecordConfig, record
from lerobot.replay import DatasetReplayConfig, ReplayConfig, replay from lerobot.replay import DatasetReplayConfig, ReplayConfig, replay
@@ -67,7 +69,14 @@ def test_record_and_resume(tmp_path):
assert dataset.meta.total_tasks == 1 assert dataset.meta.total_tasks == 1
cfg.resume = True cfg.resume = True
dataset = record(cfg) # Mock the revision to prevent Hub calls during resume
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(tmp_path / "record")
dataset = record(cfg)
assert dataset.meta.total_episodes == dataset.num_episodes == 2 assert dataset.meta.total_episodes == dataset.num_episodes == 2
assert dataset.meta.total_frames == dataset.num_frames == 6 assert dataset.meta.total_frames == dataset.num_frames == 6
@@ -103,4 +112,12 @@ def test_record_and_replay(tmp_path):
) )
record(record_cfg) record(record_cfg)
replay(replay_cfg)
# Mock the revision to prevent Hub calls during replay
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(tmp_path / "record_and_replay")
replay(replay_cfg)