def merge_video_feature_info_for_aggregate(all_metadata: list[LeRobotDatasetMetadata]) -> dict[str, dict]:
    """Build the ``features`` dict of an aggregated dataset, reconciling video encoder info.

    The features of the first dataset are deep-copied and used as the base. For every
    feature whose ``dtype`` is ``"video"``, each key listed in ``VIDEO_ENCODER_INFO_KEYS``
    is merged field-by-field across all sources: the value is kept only when every source
    agrees (a missing key counts as ``None``); otherwise it is reset to ``None`` (or ``{}``
    for ``video.extra_options``) and a single warning listing the reset keys is logged for
    that feature. ``video.video_backend`` is always set to ``"pyav"``.

    The returned dict shares no mutable object with the source metadata, so it can be
    modified freely by the caller.

    Args:
        all_metadata: List of LeRobotDatasetMetadata objects to merge. Must not be empty.

    Returns:
        dict: A dictionary of merged feature definitions.

    Raises:
        ValueError: If ``all_metadata`` is empty.
    """
    if not all_metadata:
        raise ValueError("Cannot merge video feature info: `all_metadata` is empty.")

    merged_info = copy.deepcopy(all_metadata[0].features)
    video_keys = [k for k, feature in merged_info.items() if feature.get("dtype") == "video"]

    for vk in video_keys:
        # A source may lack the feature's ``info`` entirely (or store ``None``): treat it as empty.
        video_infos = [m.features.get(vk, {}).get("info") or {} for m in all_metadata]

        merged_encoder_info: dict = {}
        fallback_keys: list[str] = []
        for info_key in VIDEO_ENCODER_INFO_KEYS:
            values = [info.get(info_key) for info in video_infos]
            first_value = values[0]

            if all(v == first_value for v in values[1:]):
                # Copy so that nested values (e.g. ``video.extra_options``) are not aliased
                # with the first source's metadata.
                merged_encoder_info[info_key] = copy.deepcopy(first_value)
            else:
                fallback_keys.append(info_key)
                merged_encoder_info[info_key] = {} if info_key == "video.extra_options" else None

        if fallback_keys:
            logging.warning(
                "Merging heterogeneous or incomplete video encoder metadata for feature %s. "
                "Setting these keys to null ({} for video.extra_options): %s",
                vk,
                fallback_keys,
            )

        # Start from the already deep-copied info of the first source, not from the source itself.
        base_video_info = merged_info[vk].get("info") or {}
        merged_info[vk]["info"] = {**base_video_info, **merged_encoder_info}
        # TODO(CarolinePascal): make this variable once we have support for other video backends.
        merged_info[vk]["info"]["video.video_backend"] = "pyav"

    return merged_info
def features_equal_for_merge(features_a: dict[str, dict], features_b: dict[str, dict]) -> bool:
    """Return whether two LeRobotDatasetMetadata ``features`` dicts are compatible for aggregation.

    Both dicts must define the same feature names with the same ``dtype``. Non-video
    features must be strictly equal. For video features, keys under ``info`` related to
    video encoding parameters (``VIDEO_ENCODER_INFO_KEYS``) are ignored during comparison
    as they do not prevent aggregation; everything else must match.

    Args:
        features_a: First ``features`` dict.
        features_b: Second ``features`` dict.

    Returns:
        bool: ``True`` when the two dicts are compatible, ``False`` otherwise. The function
        never raises on a mismatch; reporting the error is left to the caller.
    """

    def _without_encoder_info_keys(feature: dict) -> dict:
        # Shallow copy: only the ``info`` entry is rebuilt, the input feature is left untouched.
        filtered = dict(feature)
        filtered_info = filtered.get("info")
        if isinstance(filtered_info, dict):
            filtered["info"] = {
                info_key: info_value
                for info_key, info_value in filtered_info.items()
                if info_key not in VIDEO_ENCODER_INFO_KEYS
            }
        return filtered

    if set(features_a) != set(features_b):
        return False

    for key, fa_key in features_a.items():
        fb_key = features_b[key]
        if fa_key.get("dtype") != fb_key.get("dtype"):
            return False
        if fa_key.get("dtype") != "video":
            if fa_key != fb_key:
                return False
            continue

        # Video features: compare everything except the encoder-specific ``info`` keys.
        if _without_encoder_info_keys(fa_key) != _without_encoder_info_keys(fb_key):
            return False
    return True
@pytest.mark.parametrize("mutation", ["mismatched_value", "missing_key"])
def test_aggregate_incomplete_video_encoder_info_warns_and_nuls_encoders(
    tmp_path, lerobot_dataset_factory, caplog, mutation
):
    """Aggregating sources whose video encoder ``info`` disagrees or is partly missing.

    The second source's ``info.json`` is rewritten on disk so that ``video.crf`` and
    ``video.extra_options`` either differ from, or are absent compared to, the first source.
    Aggregation must log a warning, reset those two keys to their fallbacks and keep every
    other encoder key equal to the first source's value.
    """
    is_mismatch = mutation == "mismatched_value"
    suffix = "enc_mismatch" if is_mismatch else "enc_missing"

    def _make_source(tag):
        # Two small, identically shaped datasets; only name and location differ.
        return lerobot_dataset_factory(
            root=tmp_path / f"{suffix}_{tag}",
            repo_id=f"{DUMMY_REPO_ID}_{suffix}_{tag}",
            total_episodes=2,
            total_frames=20,
        )

    ds_0 = _make_source("a")
    ds_1 = _make_source("b")

    # Tamper with the encoder info of every video feature of the second source.
    info_path = ds_1.root / "meta" / "info.json"
    data = json.loads(info_path.read_text())
    for feature in data["features"].values():
        if feature.get("dtype") == "video":
            encoder_info = feature.setdefault("info", {})
            if is_mismatch:
                encoder_info["video.crf"] = 99
                encoder_info["video.extra_options"] = {"tune": "film"}
            else:
                for dropped_key in ("video.crf", "video.extra_options"):
                    encoder_info.pop(dropped_key, None)
    info_path.write_text(json.dumps(data))

    aggr_id = f"{DUMMY_REPO_ID}_{suffix}_aggr"
    aggr_root = tmp_path / f"{suffix}_aggr"
    with caplog.at_level(logging.WARNING):
        aggregate_datasets(
            repo_ids=[ds_0.repo_id, ds_1.repo_id],
            roots=[ds_0.root, ds_1.root],
            aggr_repo_id=aggr_id,
            aggr_root=aggr_root,
        )

    logged = caplog.text.lower()
    assert "heterogeneous" in logged or "incomplete" in logged

    # Load the aggregated dataset from disk without reaching the hub.
    with (
        patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(aggr_root)
        aggr_ds = LeRobotDataset(aggr_id, root=aggr_root)

    for key, feature in aggr_ds.meta.info.features.items():
        if feature.get("dtype") != "video":
            continue
        merged = feature["info"]
        reference = ds_0.meta.info.features[key]["info"]
        for info_key in VIDEO_ENCODER_INFO_KEYS:
            if info_key == "video.extra_options":
                assert merged[info_key] == {}
            elif info_key == "video.crf":
                assert merged[info_key] is None
            else:
                assert merged[info_key] == reference[info_key]