mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 11:39:50 +00:00
feat(aggregate): update the dataset aggregation procedure. Encoding tuning parameters (crf, g, ...) are ignored during validation and set to None in the aggregated dataset if they are incompatible.
This commit is contained in:
@@ -14,6 +14,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import logging
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@@ -23,7 +25,9 @@ pytest.importorskip("datasets", reason="datasets is required (install lerobot[da
|
||||
import datasets # noqa: E402
|
||||
import torch
|
||||
|
||||
from lerobot.configs import VIDEO_ENCODER_INFO_KEYS
|
||||
from lerobot.datasets.aggregate import aggregate_datasets
|
||||
from lerobot.datasets.feature_utils import features_equal_for_merge
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
|
||||
@@ -117,8 +121,9 @@ def assert_metadata_consistency(aggr_ds, ds_0, ds_1):
|
||||
"Robot type should be the same"
|
||||
)
|
||||
|
||||
# Test features are the same
|
||||
assert aggr_ds.features == ds_0.features == ds_1.features, "Features should be the same"
|
||||
# Schema matches; merged video ``info`` is reconciled separately from per-source ``info``.
|
||||
assert features_equal_for_merge(aggr_ds.features, ds_0.features)
|
||||
assert features_equal_for_merge(aggr_ds.features, ds_1.features)
|
||||
|
||||
# Test tasks aggregation
|
||||
expected_tasks = set(ds_0.meta.tasks.index) | set(ds_1.meta.tasks.index)
|
||||
@@ -284,6 +289,73 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
|
||||
assert_dataset_iteration_works(aggr_ds)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("mutation", ["mismatched_value", "missing_key"])
def test_aggregate_incomplete_video_encoder_info_warns_and_nuls_encoders(
    tmp_path, lerobot_dataset_factory, caplog, mutation
):
    """Aggregating sources with differing or missing encoder ``info`` warns and resets those keys.

    Two small datasets are created, then the second one's ``meta/info.json`` is
    rewritten on disk so that, for every video feature, ``video.crf`` and
    ``video.extra_options`` either carry different values (``mismatched_value``)
    or are removed (``missing_key``). After aggregation the test expects:

    * a warning whose text mentions "heterogeneous" or "incomplete";
    * ``video.crf`` to be ``None`` and ``video.extra_options`` to be ``{}`` in
      the aggregated dataset;
    * every other key of ``VIDEO_ENCODER_INFO_KEYS`` to equal the value found in
      the first (unmodified) source dataset.
    """
    is_mismatch = mutation == "mismatched_value"
    if is_mismatch:
        suffix = "enc_mismatch"
    else:
        suffix = "enc_missing"

    # Both sources share the same size; only their location and repo id differ.
    size_kwargs = {"total_episodes": 2, "total_frames": 20}
    ds_0 = lerobot_dataset_factory(
        root=tmp_path / f"{suffix}_a",
        repo_id=f"{DUMMY_REPO_ID}_{suffix}_a",
        **size_kwargs,
    )
    ds_1 = lerobot_dataset_factory(
        root=tmp_path / f"{suffix}_b",
        repo_id=f"{DUMMY_REPO_ID}_{suffix}_b",
        **size_kwargs,
    )

    # Tamper with the second dataset's on-disk metadata only; ds_0 stays the reference.
    info_path = ds_1.root / "meta" / "info.json"
    info_json = json.loads(info_path.read_text())
    for feature in info_json["features"].values():
        if feature.get("dtype") == "video":
            encoder_info = feature.setdefault("info", {})
            if is_mismatch:
                encoder_info["video.crf"] = 99
                encoder_info["video.extra_options"] = {"tune": "film"}
            else:
                encoder_info.pop("video.crf", None)
                encoder_info.pop("video.extra_options", None)
    info_path.write_text(json.dumps(info_json))

    aggr_id = f"{DUMMY_REPO_ID}_{suffix}_aggr"
    aggr_root = tmp_path / f"{suffix}_aggr"
    with caplog.at_level(logging.WARNING):
        aggregate_datasets(
            repo_ids=[ds_0.repo_id, ds_1.repo_id],
            roots=[ds_0.root, ds_1.root],
            aggr_repo_id=aggr_id,
            aggr_root=aggr_root,
        )

    captured = caplog.text.lower()
    assert any(marker in captured for marker in ("heterogeneous", "incomplete"))

    # Patch the version lookup and snapshot download used by dataset_metadata
    # while the aggregated dataset is loaded from aggr_root.
    with patch("lerobot.datasets.dataset_metadata.get_safe_version", return_value="v3.0"):
        with patch(
            "lerobot.datasets.dataset_metadata.snapshot_download",
            return_value=str(aggr_root),
        ):
            aggr_ds = LeRobotDataset(aggr_id, root=aggr_root)

    for feature_name, feature in aggr_ds.meta.info.features.items():
        if feature.get("dtype") == "video":
            merged_info = feature["info"]
            source_info = ds_0.meta.info.features[feature_name]["info"]
            for encoder_key in VIDEO_ENCODER_INFO_KEYS:
                merged_value = merged_info[encoder_key]
                if encoder_key == "video.crf":
                    assert merged_value is None
                elif encoder_key == "video.extra_options":
                    assert merged_value == {}
                else:
                    # Keys the test did not tamper with must match the first source.
                    assert merged_value == source_info[encoder_key]
|
||||
|
||||
|
||||
def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory):
|
||||
"""Test aggregation with small file size limits to force file rotation/sharding."""
|
||||
ds_0_num_episodes = ds_1_num_episodes = 10
|
||||
|
||||
Reference in New Issue
Block a user