diff --git a/tests/datasets/test_aggregate.py b/tests/datasets/test_aggregate.py index 7b65b234b..f1f441e61 100644 --- a/tests/datasets/test_aggregate.py +++ b/tests/datasets/test_aggregate.py @@ -15,6 +15,7 @@ # limitations under the License. import torch +from unittest.mock import patch from lerobot.datasets.aggregate import aggregate_datasets from lerobot.datasets.lerobot_dataset import LeRobotDataset @@ -207,7 +208,14 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory): aggr_root=tmp_path / "test_aggr", ) - aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_aggr", root=tmp_path / "test_aggr") + # Mock the revision to prevent Hub calls during dataset loading + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "test_aggr") + aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_aggr", root=tmp_path / "test_aggr") # Run all assertion functions expected_total_episodes = ds_0.num_episodes + ds_1.num_episodes @@ -250,7 +258,14 @@ def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory): video_files_size_in_mb=0.1, ) - aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_small_aggr", root=tmp_path / "small_aggr") + # Mock the revision to prevent Hub calls during dataset loading + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "small_aggr") + aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_small_aggr", root=tmp_path / "small_aggr") # Verify aggregation worked correctly despite file size constraints expected_total_episodes = ds_0_num_episodes + ds_1_num_episodes diff --git a/tests/test_control_robot.py b/tests/test_control_robot.py index e45688c14..8a4663a49 100644 --- a/tests/test_control_robot.py +++ b/tests/test_control_robot.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from unittest.mock import patch + from lerobot.calibrate import CalibrateConfig, calibrate from lerobot.record import DatasetRecordConfig, RecordConfig, record from lerobot.replay import DatasetReplayConfig, ReplayConfig, replay @@ -67,7 +69,14 @@ def test_record_and_resume(tmp_path): assert dataset.meta.total_tasks == 1 cfg.resume = True - dataset = record(cfg) + # Mock the revision to prevent Hub calls during resume + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "record") + dataset = record(cfg) assert dataset.meta.total_episodes == dataset.num_episodes == 2 assert dataset.meta.total_frames == dataset.num_frames == 6 @@ -103,4 +112,12 @@ def test_record_and_replay(tmp_path): ) record(record_cfg) - replay(replay_cfg) + + # Mock the revision to prevent Hub calls during replay + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "record_and_replay") + replay(replay_cfg)