mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 12:40:08 +00:00
feat(aggregate): update the dataset aggregation procedure. Encoding tuning parameters (crf, g, ...) are ignored during validation and are set to None in the aggregated dataset if they are incompatible across sources.
This commit is contained in:
@@ -15,6 +15,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import copy
|
||||||
import logging
|
import logging
|
||||||
import shutil
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -23,9 +24,11 @@ import datasets
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
import tqdm
|
import tqdm
|
||||||
|
|
||||||
|
from lerobot.configs import VIDEO_ENCODER_INFO_KEYS
|
||||||
|
|
||||||
from .compute_stats import aggregate_stats
|
from .compute_stats import aggregate_stats
|
||||||
from .dataset_metadata import LeRobotDatasetMetadata
|
from .dataset_metadata import LeRobotDatasetMetadata
|
||||||
from .feature_utils import get_hf_features_from_features
|
from .feature_utils import features_equal_for_merge, get_hf_features_from_features
|
||||||
from .io_utils import (
|
from .io_utils import (
|
||||||
get_file_size_in_mb,
|
get_file_size_in_mb,
|
||||||
get_parquet_file_size_in_mb,
|
get_parquet_file_size_in_mb,
|
||||||
@@ -46,11 +49,54 @@ from .utils import (
|
|||||||
from .video_utils import concatenate_video_files, get_video_duration_in_s
|
from .video_utils import concatenate_video_files, get_video_duration_in_s
|
||||||
|
|
||||||
|
|
||||||
|
def merge_video_feature_info_for_aggregate(all_metadata: list[LeRobotDatasetMetadata]) -> dict[str, dict]:
    """Create a merged video feature info dictionary for aggregation.

    The result starts as a deep copy of the first dataset's ``features``. For every
    feature whose ``dtype`` is ``"video"``, each key listed in ``VIDEO_ENCODER_INFO_KEYS``
    is merged field-by-field: a key keeps its value only when every source agrees
    (a missing key is treated as ``None``); otherwise it is set to ``None``
    (or ``{}`` for ``video.extra_options``) and a warning is logged.

    Args:
        all_metadata: List of LeRobotDatasetMetadata objects to merge. Must be
            non-empty; the first element provides the base features.

    Returns:
        dict: A dictionary of merged feature info. It shares no nested objects with
        the ``features`` of the source metadata.
    """
    merged_info = copy.deepcopy(all_metadata[0].features)
    video_keys = [k for k in merged_info if merged_info[k].get("dtype") == "video"]

    for vk in video_keys:
        # A source may lack the feature or its "info" entry; treat both as empty.
        video_infos = [m.features.get(vk, {}).get("info") or {} for m in all_metadata]
        # Take the base from the deep copy so the merged result never aliases
        # nested objects owned by the first source's metadata.
        base_video_info = merged_info[vk].get("info") or {}

        merged_encoder_info: dict = {}
        fallback_keys: list[str] = []
        for info_key in VIDEO_ENCODER_INFO_KEYS:
            values = [info.get(info_key) for info in video_infos]
            first_value = values[0]

            if all(v == first_value for v in values[1:]):
                # Copy: the value may be a mutable container (e.g. a dict).
                merged_encoder_info[info_key] = copy.deepcopy(first_value)
            else:
                fallback_keys.append(info_key)
                merged_encoder_info[info_key] = {} if info_key == "video.extra_options" else None

        if fallback_keys:
            logging.warning(
                "Merging heterogeneous or incomplete video encoder metadata for feature %s. "
                "Setting these keys to null ({} for video.extra_options): %s.",
                vk,
                fallback_keys,
            )

        merged_info[vk]["info"] = {**base_video_info, **merged_encoder_info}
        # TODO(CarolinePascal): make this variable once we have support for other video backends.
        merged_info[vk]["info"]["video.video_backend"] = "pyav"

    return merged_info
|
||||||
|
|
||||||
|
|
||||||
def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
|
def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
|
||||||
"""Validates that all dataset metadata have consistent properties.
|
"""Validates that all dataset metadata have consistent properties.
|
||||||
|
|
||||||
Ensures all datasets have the same fps, robot_type, and features to guarantee
|
Ensures all datasets have the same fps, robot_type, and features to guarantee
|
||||||
compatibility when aggregating them into a single dataset.
|
compatibility when aggregating them into a single dataset.
|
||||||
|
Video encoder info is not considered for validation but is merged during aggregation in ``merge_video_feature_info_for_aggregate``.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
all_metadata: List of LeRobotDatasetMetadata objects to validate.
|
all_metadata: List of LeRobotDatasetMetadata objects to validate.
|
||||||
@@ -74,7 +120,7 @@ def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Same robot_type is expected, but got robot_type={meta.robot_type} instead of {robot_type}."
|
f"Same robot_type is expected, but got robot_type={meta.robot_type} instead of {robot_type}."
|
||||||
)
|
)
|
||||||
if features != meta.features:
|
if not features_equal_for_merge(features, meta.features):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Same features is expected, but got features={meta.features} instead of {features}."
|
f"Same features is expected, but got features={meta.features} instead of {features}."
|
||||||
)
|
)
|
||||||
@@ -274,7 +320,8 @@ def aggregate_datasets(
|
|||||||
LeRobotDatasetMetadata(repo_id, root=root) for repo_id, root in zip(repo_ids, roots, strict=False)
|
LeRobotDatasetMetadata(repo_id, root=root) for repo_id, root in zip(repo_ids, roots, strict=False)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
fps, robot_type, features = validate_all_metadata(all_metadata)
|
fps, robot_type, _ = validate_all_metadata(all_metadata)
|
||||||
|
features = merge_video_feature_info_for_aggregate(all_metadata)
|
||||||
video_keys = [key for key in features if features[key]["dtype"] == "video"]
|
video_keys = [key for key in features if features[key]["dtype"] == "video"]
|
||||||
|
|
||||||
dst_meta = LeRobotDatasetMetadata.create(
|
dst_meta = LeRobotDatasetMetadata.create(
|
||||||
@@ -413,6 +460,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
|||||||
current_dst_duration = dst_file_durations.get(dst_key, 0)
|
current_dst_duration = dst_file_durations.get(dst_key, 0)
|
||||||
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_dst_duration
|
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_dst_duration
|
||||||
videos_idx[key]["src_to_dst"][(src_chunk_idx, src_file_idx)] = dst_key
|
videos_idx[key]["src_to_dst"][(src_chunk_idx, src_file_idx)] = dst_key
|
||||||
|
# TODO(CarolinePascal): Move the check before the loop to avoid failing in the middle + add possibility to re-encode the video if the check fails
|
||||||
concatenate_video_files(
|
concatenate_video_files(
|
||||||
[dst_path, src_path],
|
[dst_path, src_path],
|
||||||
dst_path,
|
dst_path,
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import datasets
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image as PILImage
|
from PIL import Image as PILImage
|
||||||
|
|
||||||
|
from lerobot.configs import VIDEO_ENCODER_INFO_KEYS
|
||||||
from lerobot.utils.constants import DEFAULT_FEATURES
|
from lerobot.utils.constants import DEFAULT_FEATURES
|
||||||
from lerobot.utils.utils import is_valid_numpy_dtype_string
|
from lerobot.utils.utils import is_valid_numpy_dtype_string
|
||||||
|
|
||||||
@@ -108,6 +109,42 @@ def create_empty_dataset_info(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def features_equal_for_merge(features_a: dict[str, dict], features_b: dict[str, dict]) -> bool:
    """Return whether two LeRobotDatasetMetadata ``features`` dicts are compatible for aggregation.

    Both dicts must have the same feature names, and each pair of features must have
    the same ``dtype``. Non-video features must be strictly equal. For video features,
    keys under ``info`` that are listed in ``VIDEO_ENCODER_INFO_KEYS`` are ignored during
    comparison as they do not prevent aggregation; everything else must match.

    Args:
        features_a: First ``features`` mapping (feature name -> feature spec).
        features_b: Second ``features`` mapping (feature name -> feature spec).

    Returns:
        bool: ``True`` if the two mappings are compatible, ``False`` otherwise.
        This function never raises on a mismatch; callers decide how to report it.
    """

    def _without_encoder_info_keys(feature: dict) -> dict:
        # Return a shallow copy of ``feature`` whose "info" dict has the encoder keys removed.
        filtered = dict(feature)
        info = filtered.get("info")
        if not isinstance(info, dict):
            # No "info" dict to filter; compare the feature as-is.
            return filtered
        filtered["info"] = {
            info_key: info_value
            for info_key, info_value in info.items()
            if info_key not in VIDEO_ENCODER_INFO_KEYS
        }
        return filtered

    if set(features_a) != set(features_b):
        return False

    for key, feature_a in features_a.items():
        feature_b = features_b[key]
        if feature_a.get("dtype") != feature_b.get("dtype"):
            return False

        if feature_a.get("dtype") != "video":
            if feature_a != feature_b:
                return False
            continue

        # Video features: a mismatch outside the encoder keys is reported as False,
        # in line with the documented bool contract, rather than raised.
        if _without_encoder_info_keys(feature_a) != _without_encoder_info_keys(feature_b):
            return False

    return True
|
||||||
|
|
||||||
|
|
||||||
def check_delta_timestamps(
|
def check_delta_timestamps(
|
||||||
delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True
|
delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True
|
||||||
) -> bool:
|
) -> bool:
|
||||||
|
|||||||
@@ -14,6 +14,8 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -23,7 +25,9 @@ pytest.importorskip("datasets", reason="datasets is required (install lerobot[da
|
|||||||
import datasets # noqa: E402
|
import datasets # noqa: E402
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from lerobot.configs import VIDEO_ENCODER_INFO_KEYS
|
||||||
from lerobot.datasets.aggregate import aggregate_datasets
|
from lerobot.datasets.aggregate import aggregate_datasets
|
||||||
|
from lerobot.datasets.feature_utils import features_equal_for_merge
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||||
|
|
||||||
@@ -117,8 +121,9 @@ def assert_metadata_consistency(aggr_ds, ds_0, ds_1):
|
|||||||
"Robot type should be the same"
|
"Robot type should be the same"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Test features are the same
|
# Schema matches; merged video ``info`` is reconciled separately from per-source ``info``.
|
||||||
assert aggr_ds.features == ds_0.features == ds_1.features, "Features should be the same"
|
assert features_equal_for_merge(aggr_ds.features, ds_0.features)
|
||||||
|
assert features_equal_for_merge(aggr_ds.features, ds_1.features)
|
||||||
|
|
||||||
# Test tasks aggregation
|
# Test tasks aggregation
|
||||||
expected_tasks = set(ds_0.meta.tasks.index) | set(ds_1.meta.tasks.index)
|
expected_tasks = set(ds_0.meta.tasks.index) | set(ds_1.meta.tasks.index)
|
||||||
@@ -284,6 +289,73 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
|
|||||||
assert_dataset_iteration_works(aggr_ds)
|
assert_dataset_iteration_works(aggr_ds)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("mutation", ["mismatched_value", "missing_key"])
def test_aggregate_incomplete_video_encoder_info_warns_and_nuls_encoders(
    tmp_path, lerobot_dataset_factory, caplog, mutation
):
    """Aggregating sources whose encoder ``info`` disagrees or is partly absent warns and falls back per key."""
    is_mismatch = mutation == "mismatched_value"
    suffix = "enc_mismatch" if is_mismatch else "enc_missing"

    # Two otherwise identical source datasets; only the second one is tampered with below.
    source_specs = (
        (tmp_path / f"{suffix}_a", f"{DUMMY_REPO_ID}_{suffix}_a"),
        (tmp_path / f"{suffix}_b", f"{DUMMY_REPO_ID}_{suffix}_b"),
    )
    ds_0, ds_1 = (
        lerobot_dataset_factory(root=src_root, repo_id=src_id, total_episodes=2, total_frames=20)
        for src_root, src_id in source_specs
    )

    # Rewrite the second dataset's info.json so its video encoder info differs from the first.
    info_file = ds_1.root / "meta" / "info.json"
    payload = json.loads(info_file.read_text())
    video_specs = [spec for spec in payload["features"].values() if spec.get("dtype") == "video"]
    for spec in video_specs:
        encoder_info = spec.setdefault("info", {})
        if is_mismatch:
            encoder_info["video.crf"] = 99
            encoder_info["video.extra_options"] = {"tune": "film"}
        else:
            encoder_info.pop("video.crf", None)
            encoder_info.pop("video.extra_options", None)
    info_file.write_text(json.dumps(payload))

    aggr_id = f"{DUMMY_REPO_ID}_{suffix}_aggr"
    aggr_root = tmp_path / f"{suffix}_aggr"
    with caplog.at_level(logging.WARNING):
        aggregate_datasets(
            repo_ids=[ds_0.repo_id, ds_1.repo_id],
            roots=[ds_0.root, ds_1.root],
            aggr_repo_id=aggr_id,
            aggr_root=aggr_root,
        )

    logged = caplog.text.lower()
    assert any(word in logged for word in ("heterogeneous", "incomplete"))

    with (
        patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(aggr_root)
        aggr_ds = LeRobotDataset(aggr_id, root=aggr_root)

    # The tampered keys must fall back; every other encoder key must match the first source.
    for name, spec in aggr_ds.meta.info.features.items():
        if spec.get("dtype") != "video":
            continue
        merged = spec["info"]
        reference = ds_0.meta.info.features[name]["info"]
        for info_key in VIDEO_ENCODER_INFO_KEYS:
            if info_key == "video.crf":
                assert merged[info_key] is None
            elif info_key == "video.extra_options":
                assert merged[info_key] == {}
            else:
                assert merged[info_key] == reference[info_key]
|
||||||
|
|
||||||
|
|
||||||
def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory):
|
def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory):
|
||||||
"""Test aggregation with small file size limits to force file rotation/sharding."""
|
"""Test aggregation with small file size limits to force file rotation/sharding."""
|
||||||
ds_0_num_episodes = ds_1_num_episodes = 10
|
ds_0_num_episodes = ds_1_num_episodes = 10
|
||||||
|
|||||||
Reference in New Issue
Block a user