mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 11:39:50 +00:00
Added mock context manager to tests in order to avoid calls to the hub for dummy datasets
This commit is contained in:
@@ -15,6 +15,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
from lerobot.datasets.aggregate import aggregate_datasets
|
from lerobot.datasets.aggregate import aggregate_datasets
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
@@ -207,7 +208,14 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
|
|||||||
aggr_root=tmp_path / "test_aggr",
|
aggr_root=tmp_path / "test_aggr",
|
||||||
)
|
)
|
||||||
|
|
||||||
aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_aggr", root=tmp_path / "test_aggr")
|
# Mock the revision to prevent Hub calls during dataset loading
|
||||||
|
with (
|
||||||
|
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||||
|
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||||
|
):
|
||||||
|
mock_get_safe_version.return_value = "v3.0"
|
||||||
|
mock_snapshot_download.return_value = str(tmp_path / "test_aggr")
|
||||||
|
aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_aggr", root=tmp_path / "test_aggr")
|
||||||
|
|
||||||
# Run all assertion functions
|
# Run all assertion functions
|
||||||
expected_total_episodes = ds_0.num_episodes + ds_1.num_episodes
|
expected_total_episodes = ds_0.num_episodes + ds_1.num_episodes
|
||||||
@@ -250,7 +258,14 @@ def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory):
|
|||||||
video_files_size_in_mb=0.1,
|
video_files_size_in_mb=0.1,
|
||||||
)
|
)
|
||||||
|
|
||||||
aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_small_aggr", root=tmp_path / "small_aggr")
|
# Mock the revision to prevent Hub calls during dataset loading
|
||||||
|
with (
|
||||||
|
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||||
|
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||||
|
):
|
||||||
|
mock_get_safe_version.return_value = "v3.0"
|
||||||
|
mock_snapshot_download.return_value = str(tmp_path / "small_aggr")
|
||||||
|
aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_small_aggr", root=tmp_path / "small_aggr")
|
||||||
|
|
||||||
# Verify aggregation worked correctly despite file size constraints
|
# Verify aggregation worked correctly despite file size constraints
|
||||||
expected_total_episodes = ds_0_num_episodes + ds_1_num_episodes
|
expected_total_episodes = ds_0_num_episodes + ds_1_num_episodes
|
||||||
|
|||||||
@@ -14,6 +14,8 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
from lerobot.calibrate import CalibrateConfig, calibrate
|
from lerobot.calibrate import CalibrateConfig, calibrate
|
||||||
from lerobot.record import DatasetRecordConfig, RecordConfig, record
|
from lerobot.record import DatasetRecordConfig, RecordConfig, record
|
||||||
from lerobot.replay import DatasetReplayConfig, ReplayConfig, replay
|
from lerobot.replay import DatasetReplayConfig, ReplayConfig, replay
|
||||||
@@ -67,7 +69,14 @@ def test_record_and_resume(tmp_path):
|
|||||||
assert dataset.meta.total_tasks == 1
|
assert dataset.meta.total_tasks == 1
|
||||||
|
|
||||||
cfg.resume = True
|
cfg.resume = True
|
||||||
dataset = record(cfg)
|
# Mock the revision to prevent Hub calls during resume
|
||||||
|
with (
|
||||||
|
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||||
|
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||||
|
):
|
||||||
|
mock_get_safe_version.return_value = "v3.0"
|
||||||
|
mock_snapshot_download.return_value = str(tmp_path / "record")
|
||||||
|
dataset = record(cfg)
|
||||||
|
|
||||||
assert dataset.meta.total_episodes == dataset.num_episodes == 2
|
assert dataset.meta.total_episodes == dataset.num_episodes == 2
|
||||||
assert dataset.meta.total_frames == dataset.num_frames == 6
|
assert dataset.meta.total_frames == dataset.num_frames == 6
|
||||||
@@ -103,4 +112,12 @@ def test_record_and_replay(tmp_path):
|
|||||||
)
|
)
|
||||||
|
|
||||||
record(record_cfg)
|
record(record_cfg)
|
||||||
replay(replay_cfg)
|
|
||||||
|
# Mock the revision to prevent Hub calls during replay
|
||||||
|
with (
|
||||||
|
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||||
|
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||||
|
):
|
||||||
|
mock_get_safe_version.return_value = "v3.0"
|
||||||
|
mock_snapshot_download.return_value = str(tmp_path / "record_and_replay")
|
||||||
|
replay(replay_cfg)
|
||||||
|
|||||||
Reference in New Issue
Block a user