Added mock context manager to tests in order to avoid calls to the hub for dummy datasets

This commit is contained in:
Michel Aractingi
2025-07-30 12:11:16 +02:00
parent 527ae8e557
commit 1c79e3dec1
2 changed files with 36 additions and 4 deletions
+19 -2
View File
@@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import patch
from lerobot.calibrate import CalibrateConfig, calibrate
from lerobot.record import DatasetRecordConfig, RecordConfig, record
from lerobot.replay import DatasetReplayConfig, ReplayConfig, replay
@@ -67,7 +69,14 @@ def test_record_and_resume(tmp_path):
assert dataset.meta.total_tasks == 1
cfg.resume = True
dataset = record(cfg)
# Mock the revision to prevent Hub calls during resume
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(tmp_path / "record")
dataset = record(cfg)
assert dataset.meta.total_episodes == dataset.num_episodes == 2
assert dataset.meta.total_frames == dataset.num_frames == 6
@@ -103,4 +112,12 @@ def test_record_and_replay(tmp_path):
)
record(record_cfg)
replay(replay_cfg)
# Mock the revision to prevent Hub calls during replay
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(tmp_path / "record_and_replay")
replay(replay_cfg)