refactor(dataset): modular files (#3171)

* refactor(dataset): modular files

* refactor(dataset): update imports across the codebase
This commit is contained in:
Steven Palma
2026-03-15 23:58:09 -07:00
committed by GitHub
parent 9d3b62aa61
commit d90e4bcfd3
53 changed files with 2030 additions and 1901 deletions
+12 -12
View File
@@ -260,8 +260,8 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
# Mock the revision to prevent Hub calls during dataset loading
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(tmp_path / "test_aggr")
@@ -311,8 +311,8 @@ def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory):
# Mock the revision to prevent Hub calls during dataset loading
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(tmp_path / "small_aggr")
@@ -367,8 +367,8 @@ def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory):
)
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(tmp_path / "regression_aggr")
@@ -492,8 +492,8 @@ def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory):
# Load the aggregated dataset
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(tmp_path / "image_aggr")
@@ -562,8 +562,8 @@ def test_aggregate_already_merged_dataset(tmp_path, lerobot_dataset_factory):
)
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(tmp_path / "ds_ab")
@@ -590,8 +590,8 @@ def test_aggregate_already_merged_dataset(tmp_path, lerobot_dataset_factory):
)
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(tmp_path / "ds_abc")