Files
lerobot/tests/datasets/test_aggregate.py
T
2026-06-15 17:56:31 +02:00

805 lines
34 KiB
Python

#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
from unittest.mock import patch
import pytest
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
import datasets # noqa: E402
import torch
from lerobot.configs import VIDEO_ENCODER_INFO_KEYS
from lerobot.datasets.aggregate import aggregate_datasets
from lerobot.datasets.feature_utils import features_equal_for_merge
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from tests.fixtures.constants import (
DUMMY_CAMERA_FEATURES_WITH_DEPTH,
DUMMY_REPO_ID,
)
def assert_episode_and_frame_counts(aggr_ds, expected_episodes, expected_frames):
    """Verify that the aggregated dataset reports the expected totals.

    Args:
        aggr_ds: Object exposing ``num_episodes`` and ``num_frames``.
        expected_episodes: Episode count the aggregated dataset must report.
        expected_frames: Frame count the aggregated dataset must report.

    Raises:
        AssertionError: If either total differs from the expected value.
    """
    episode_total = aggr_ds.num_episodes
    assert episode_total == expected_episodes, f"Expected {expected_episodes} episodes, got {episode_total}"

    frame_total = aggr_ds.num_frames
    assert frame_total == expected_frames, f"Expected {expected_frames} frames, got {frame_total}"
def _assert_items_match(aggr_item, src_item, keys_to_ignore, label, src_name):
    """Assert that every non-ignored key of ``src_item`` has an equal value in ``aggr_item``.

    Tensor pairs are compared with ``torch.allclose`` (``atol=1e-6``); any other
    pair is compared with ``==``.
    """
    for key in src_item:
        if key in keys_to_ignore:
            continue
        aggr_value = aggr_item[key]
        src_value = src_item[key]
        message = f"{label} key '{key}' doesn't match between aggregated and {src_name}"
        if torch.is_tensor(aggr_value) and torch.is_tensor(src_value):
            assert torch.allclose(aggr_value, src_value, atol=1e-6), message
        else:
            assert aggr_value == src_value, message


def assert_dataset_content_integrity(aggr_ds, ds_0, ds_1):
    """Test that the content of both datasets is preserved correctly in the aggregated dataset.

    Four boundary samples are compared: the first and last item of the ``ds_0``
    section and the first and last item of the ``ds_1`` section.

    Args:
        aggr_ds: Aggregated dataset, indexable by position (including ``-1``).
        ds_0: First source dataset; must support ``len()`` and indexing.
        ds_1: Second source dataset; must support indexing.

    Raises:
        AssertionError: If any compared key differs between aggregated and source item.
    """
    # These keys are excluded from the comparison (the ds_1 section is
    # compared under different values for them).
    keys_to_ignore = ["episode_index", "index", "timestamp"]
    ds_0_len = len(ds_0)

    # ds_0 occupies positions [0, len(ds_0)) of the aggregated dataset.
    _assert_items_match(aggr_ds[0], ds_0[0], keys_to_ignore, "First item", "ds_0")
    _assert_items_match(aggr_ds[ds_0_len - 1], ds_0[-1], keys_to_ignore, "Last ds_0 item", "ds_0")

    # ds_1 starts right after ds_0 and runs to the end.
    _assert_items_match(aggr_ds[ds_0_len], ds_1[0], keys_to_ignore, "First ds_1 item", "ds_1")
    _assert_items_match(aggr_ds[-1], ds_1[-1], keys_to_ignore, "Last item", "ds_1")
def assert_metadata_consistency(aggr_ds, ds_0, ds_1):
    """Check that fps, robot type, feature schema and tasks agree after aggregation.

    Raises:
        AssertionError: If any of the metadata checks fails.
    """
    sources = (ds_0, ds_1)

    # Basic info must be identical across the aggregate and both sources.
    aggr_fps = aggr_ds.fps
    assert aggr_fps == ds_0.fps == ds_1.fps, "FPS should be the same across all datasets"

    aggr_robot_type = aggr_ds.meta.info.robot_type
    assert aggr_robot_type == ds_0.meta.info.robot_type == ds_1.meta.info.robot_type, (
        "Robot type should be the same"
    )

    # Schema matches; merged video ``info`` is reconciled separately from per-source ``info``.
    for source in sources:
        assert features_equal_for_merge(aggr_ds.features, source.features)

    # The aggregated task index must be the union of the source task indices.
    expected_tasks = set(ds_0.meta.tasks.index).union(ds_1.meta.tasks.index)
    actual_tasks = set(aggr_ds.meta.tasks.index)
    assert actual_tasks == expected_tasks, f"Expected tasks {expected_tasks}, got {actual_tasks}"
def assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1):
    """Check that ``episode_index`` values were re-based during aggregation.

    Frames in the ``ds_0`` section must carry an episode index below
    ``ds_0.num_episodes``; frames in the ``ds_1`` section must fall in
    ``[ds_0.num_episodes, ds_0.num_episodes + ds_1.num_episodes)``.

    Raises:
        AssertionError: If any frame carries an episode index outside its section's range.
    """
    ds_0_frames = len(ds_0)
    ds_0_episodes = ds_0.num_episodes

    # Section belonging to ds_0.
    for position in range(ds_0_frames):
        episode_index = aggr_ds[position]["episode_index"]
        assert episode_index < ds_0_episodes, (
            f"Episode index {episode_index} at position {position} should be < {ds_0_episodes}"
        )

    # Section belonging to ds_1, shifted by the number of ds_0 episodes.
    upper_bound = ds_0_episodes + ds_1.num_episodes
    for position in range(ds_0_frames, ds_0_frames + len(ds_1)):
        episode_index = aggr_ds[position]["episode_index"]
        in_ds_1_range = (episode_index >= ds_0_episodes) and (episode_index < upper_bound)
        assert in_ds_1_range, (
            f"Episode index {episode_index} at position {position} should be >= {ds_0_episodes}"
        )
def assert_video_frames_integrity(aggr_ds, ds_0, ds_1):
    """Test that video frames are correctly preserved and frame indices are updated.

    For every position of the aggregated dataset, the ``index`` field must equal
    the position, and every video feature must be ``torch.allclose`` to the frame
    at the matching position of the source dataset.

    Each sample is fetched once per position instead of once per checked key,
    so indexing work is not repeated for every video key.

    Args:
        aggr_ds: Aggregated dataset; indexable, with ``meta.info.features``.
        ds_0: First source dataset (occupies the leading positions).
        ds_1: Second source dataset (follows ``ds_0``).

    Raises:
        AssertionError: On a frame index mismatch or a differing video frame.
    """
    features = aggr_ds.meta.info.features
    video_keys = [key for key in features.keys() if features[key]["dtype"] == "video"]

    offset = 0
    for source_name, source in (("ds_0", ds_0), ("ds_1", ds_1)):
        for local_idx in range(len(source)):
            position = offset + local_idx
            aggr_item = aggr_ds[position]
            # The frame index in the aggregated dataset must match its position.
            assert aggr_item["index"] == position, (
                f"Frame index at position {position} should be {position}, but got {aggr_item['index']}"
            )
            if not video_keys:
                # Nothing to compare: do not touch the source sample at all.
                continue
            source_item = source[local_idx]
            for key in video_keys:
                assert torch.allclose(aggr_item[key], source_item[key]), (
                    f"Visual frames at position {position} should be equal between aggregated and {source_name}"
                )
        offset += len(source)
def assert_dataset_iteration_works(aggr_ds):
    """Consume the whole dataset once; any error raised while iterating propagates."""
    for _sample in iter(aggr_ds):
        continue
def assert_depth_keys_preserved(aggr_ds, ds_0, ds_1):
    """Check that depth keys and their ``is_depth_map`` marker survive aggregation.

    Both sources must agree on their ``meta.depth_keys``; the aggregated dataset
    must expose the same set, and each of those features must still carry
    ``info["is_depth_map"] is True`` so that downstream consumers (e.g. the
    dataset reader's depth decoding path) keep working on the merged dataset.

    Raises:
        AssertionError: If the sources disagree, the aggregated key set differs,
            or a depth feature lost its marker.
    """
    source_depth_keys = set(ds_0.meta.depth_keys)
    other_depth_keys = set(ds_1.meta.depth_keys)
    assert source_depth_keys == other_depth_keys, (
        "Source datasets disagree on depth_keys; test setup is inconsistent"
    )

    merged_depth_keys = set(aggr_ds.meta.depth_keys)
    assert merged_depth_keys == source_depth_keys, (
        f"Expected depth_keys {source_depth_keys}, got {merged_depth_keys}"
    )

    merged_features = aggr_ds.meta.info.features
    for key in source_depth_keys:
        feature_info = merged_features[key].get("info")
        if not feature_info:
            # A missing or empty ``info`` entry is treated as an empty mapping.
            feature_info = {}
        marker = feature_info.get("is_depth_map")
        assert marker is True, f"Depth marker lost on feature {key!r} after aggregation"
def assert_video_timestamps_within_bounds(aggr_ds):
    """Test that all video timestamps are within valid bounds for their respective video files.

    This catches bugs where timestamps point to frames beyond the actual video length,
    which would cause "Invalid frame index" errors during data loading.

    For every episode and video key, the ``from``/``to`` timestamps stored in the
    episode metadata are converted to frame indices (``round(ts * fps)``) and
    compared against the number of frames reported by the decoder for that file.
    Episodes whose video file does not exist on disk are skipped.

    The frame count of each video file is read once and cached, so a file shared
    by several episodes is not re-opened for each of them.

    Args:
        aggr_ds: Aggregated dataset exposing ``num_episodes``, ``fps``, ``root``,
            ``meta.episodes``, ``meta.video_keys`` and ``meta.get_video_file_path``.

    Raises:
        AssertionError: If a frame index falls outside its video file, if the
            episode range is empty, or if the video file cannot be opened.
    """
    try:
        from torchcodec.decoders import VideoDecoder
    except ImportError:
        # Deliberate best-effort: without torchcodec the check is silently skipped.
        return

    frame_counts = {}  # video path -> number of frames reported by the decoder

    for ep_idx in range(aggr_ds.num_episodes):
        ep = aggr_ds.meta.episodes[ep_idx]
        for vid_key in aggr_ds.meta.video_keys:
            from_ts = ep[f"videos/{vid_key}/from_timestamp"]
            to_ts = ep[f"videos/{vid_key}/to_timestamp"]
            video_path = aggr_ds.root / aggr_ds.meta.get_video_file_path(ep_idx, vid_key)
            if not video_path.exists():
                continue

            from_frame_idx = round(from_ts * aggr_ds.fps)
            to_frame_idx = round(to_ts * aggr_ds.fps)

            num_frames = frame_counts.get(video_path)
            if num_frames is None:
                # Only decoder access is guarded, so a failing bound assertion
                # below surfaces with its own message instead of being re-wrapped.
                try:
                    num_frames = len(VideoDecoder(str(video_path)))
                except Exception as e:
                    raise AssertionError(
                        f"Failed to verify timestamps for episode {ep_idx}, {vid_key}: {e}"
                    ) from e
                frame_counts[video_path] = num_frames

            # Verify timestamps don't exceed video bounds
            assert from_frame_idx >= 0, f"Episode {ep_idx}, {vid_key}: from_frame_idx ({from_frame_idx}) < 0"
            assert from_frame_idx < num_frames, (
                f"Episode {ep_idx}, {vid_key}: from_frame_idx ({from_frame_idx}) >= video frames ({num_frames})"
            )
            assert to_frame_idx <= num_frames, (
                f"Episode {ep_idx}, {vid_key}: to_frame_idx ({to_frame_idx}) > video frames ({num_frames})"
            )
            assert from_frame_idx < to_frame_idx, (
                f"Episode {ep_idx}, {vid_key}: from_frame_idx ({from_frame_idx}) >= to_frame_idx ({to_frame_idx})"
            )
def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
    """Aggregate two differently sized datasets with default parameters and validate the result.

    Source datasets include both RGB and depth video features so the same
    aggregation flow is exercised on the ``is_depth_map`` branch.
    """
    aggr_repo_id = f"{DUMMY_REPO_ID}_aggr"
    aggr_root = tmp_path / "test_aggr"

    # Two sources with different numbers of frames and episodes.
    source_specs = (
        {"root": tmp_path / "test_0", "repo_id": f"{DUMMY_REPO_ID}_0", "total_episodes": 10, "total_frames": 400},
        {"root": tmp_path / "test_1", "repo_id": f"{DUMMY_REPO_ID}_1", "total_episodes": 25, "total_frames": 800},
    )
    ds_0, ds_1 = (
        lerobot_dataset_factory(camera_features=DUMMY_CAMERA_FEATURES_WITH_DEPTH, **spec)
        for spec in source_specs
    )

    # Confirm depth was actually wired into the source datasets so the
    # rest of the assertions exercise the depth aggregation path.
    for label, source in (("ds_0", ds_0), ("ds_1", ds_1)):
        assert len(source.meta.depth_keys) > 0, f"{label} should expose at least one depth key"

    aggregate_datasets(
        repo_ids=[ds_0.repo_id, ds_1.repo_id],
        roots=[ds_0.root, ds_1.root],
        aggr_repo_id=aggr_repo_id,
        aggr_root=aggr_root,
    )

    # Mock the revision lookup and download to prevent Hub calls during dataset loading.
    with (
        patch("lerobot.datasets.dataset_metadata.get_safe_version", return_value="v3.0"),
        patch("lerobot.datasets.dataset_metadata.snapshot_download", return_value=str(aggr_root)),
    ):
        aggr_ds = LeRobotDataset(aggr_repo_id, root=aggr_root)

    # Run all assertion helpers, in the same order as they are documented.
    assert_episode_and_frame_counts(
        aggr_ds,
        ds_0.num_episodes + ds_1.num_episodes,
        ds_0.num_frames + ds_1.num_frames,
    )
    three_way_checks = (
        assert_dataset_content_integrity,
        assert_metadata_consistency,
        assert_episode_indices_updated_correctly,
        assert_video_frames_integrity,
    )
    for check in three_way_checks:
        check(aggr_ds, ds_0, ds_1)
    assert_video_timestamps_within_bounds(aggr_ds)
    assert_depth_keys_preserved(aggr_ds, ds_0, ds_1)
    assert_dataset_iteration_works(aggr_ds)
def test_aggregate_datasets_without_concatenation(tmp_path, lerobot_dataset_factory):
    """With concatenation disabled, each source file is kept as its own destination file."""
    aggr_repo_id = f"{DUMMY_REPO_ID}_no_stitch_aggr"
    aggr_root = tmp_path / "no_stitch_aggr"

    # (episodes, frames) for the two sources.
    sizes = ((3, 60), (4, 80))
    ds_0, ds_1 = (
        lerobot_dataset_factory(
            root=tmp_path / f"no_stitch_{idx}",
            repo_id=f"{DUMMY_REPO_ID}_no_stitch_{idx}",
            total_episodes=episodes,
            total_frames=frames,
        )
        for idx, (episodes, frames) in enumerate(sizes)
    )

    aggregate_datasets(
        repo_ids=[ds_0.repo_id, ds_1.repo_id],
        roots=[ds_0.root, ds_1.root],
        aggr_repo_id=aggr_repo_id,
        aggr_root=aggr_root,
        concatenate_videos=False,
        concatenate_data=False,
    )

    # Keep dataset loading local: no Hub version lookup, no download.
    with (
        patch("lerobot.datasets.dataset_metadata.get_safe_version", return_value="v3.0"),
        patch("lerobot.datasets.dataset_metadata.snapshot_download", return_value=str(aggr_root)),
    ):
        aggr_ds = LeRobotDataset(aggr_repo_id, root=aggr_root)

    total_episodes = ds_0.num_episodes + ds_1.num_episodes
    total_frames = ds_0.num_frames + ds_1.num_frames
    assert_episode_and_frame_counts(aggr_ds, total_episodes, total_frames)
    assert_dataset_iteration_works(aggr_ds)
    assert_video_timestamps_within_bounds(aggr_ds)

    # Two single-file sources stay as two files each, instead of being packed together.
    parquet_count = sum(1 for _ in (aggr_root / "data").rglob("*.parquet"))
    assert parquet_count == 2

    video_keys = aggr_ds.meta.video_keys
    assert video_keys, "Test fixture should produce at least one video feature"
    for key in video_keys:
        mp4_count = sum(1 for _ in (aggr_root / "videos" / key).rglob("*.mp4"))
        assert mp4_count == 2
@pytest.mark.parametrize("mutation", ["mismatched_value", "missing_key"])
def test_aggregate_incomplete_video_encoder_info_warns_and_nuls_encoders(
    tmp_path, lerobot_dataset_factory, caplog, mutation
):
    """Mismatched or missing encoder ``info`` is merged per-key with fallbacks and a warning.

    The second source's ``meta/info.json`` is rewritten so that ``video.crf`` and
    ``video.extra_options`` either differ from the first source or are absent.
    After aggregation those two keys must be ``None`` and ``{}`` respectively,
    while every other encoder key must equal the first source's value.
    """
    if mutation == "mismatched_value":
        suffix = "enc_mismatch"
    else:
        suffix = "enc_missing"

    ds_0, ds_1 = (
        lerobot_dataset_factory(
            root=tmp_path / f"{suffix}_{tag}",
            repo_id=f"{DUMMY_REPO_ID}_{suffix}_{tag}",
            total_episodes=2,
            total_frames=20,
        )
        for tag in ("a", "b")
    )

    # Tamper with the encoder info of every video feature of the second source.
    info_path = ds_1.root / "meta" / "info.json"
    info_payload = json.loads(info_path.read_text())
    video_features = [ft for ft in info_payload["features"].values() if ft.get("dtype") == "video"]
    for video_ft in video_features:
        encoder_info = video_ft.setdefault("info", {})
        if mutation == "mismatched_value":
            encoder_info["video.crf"] = 99
            encoder_info["video.extra_options"] = {"tune": "film"}
        else:
            for dropped_key in ("video.crf", "video.extra_options"):
                encoder_info.pop(dropped_key, None)
    info_path.write_text(json.dumps(info_payload))

    aggr_id = f"{DUMMY_REPO_ID}_{suffix}_aggr"
    aggr_root = tmp_path / f"{suffix}_aggr"
    with caplog.at_level(logging.WARNING):
        aggregate_datasets(
            repo_ids=[ds_0.repo_id, ds_1.repo_id],
            roots=[ds_0.root, ds_1.root],
            aggr_repo_id=aggr_id,
            aggr_root=aggr_root,
        )
    logged_text = caplog.text.lower()
    assert "heterogeneous" in logged_text or "incomplete" in logged_text

    # Load the merged dataset without touching the Hub.
    with (
        patch("lerobot.datasets.dataset_metadata.get_safe_version", return_value="v3.0"),
        patch("lerobot.datasets.dataset_metadata.snapshot_download", return_value=str(aggr_root)),
    ):
        aggr_ds = LeRobotDataset(aggr_id, root=aggr_root)

    for key, ft in aggr_ds.meta.info.features.items():
        if ft.get("dtype") != "video":
            continue
        merged_info = ft["info"]
        reference_info = ds_0.meta.info.features[key]["info"]
        for info_key in VIDEO_ENCODER_INFO_KEYS:
            merged_value = merged_info[info_key]
            if info_key == "video.crf":
                assert merged_value is None
            elif info_key == "video.extra_options":
                assert merged_value == {}
            else:
                assert merged_value == reference_info[info_key]
def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory):
    """Test aggregation with small file size limits to force file rotation/sharding.

    Depth video features are included to verify that file rotation/concat
    correctly handles depth-marked features alongside regular RGB ones.

    The file-count checks at the end are unconditional: a missing ``data`` or
    ``videos`` directory fails the test instead of silently skipping the check.
    """
    ds_0_num_episodes = ds_1_num_episodes = 10
    ds_0_num_frames = ds_1_num_frames = 400
    aggr_repo_id = f"{DUMMY_REPO_ID}_small_aggr"
    aggr_root = tmp_path / "small_aggr"

    ds_0 = lerobot_dataset_factory(
        root=tmp_path / "small_0",
        repo_id=f"{DUMMY_REPO_ID}_small_0",
        total_episodes=ds_0_num_episodes,
        total_frames=ds_0_num_frames,
        camera_features=DUMMY_CAMERA_FEATURES_WITH_DEPTH,
    )
    ds_1 = lerobot_dataset_factory(
        root=tmp_path / "small_1",
        repo_id=f"{DUMMY_REPO_ID}_small_1",
        total_episodes=ds_1_num_episodes,
        total_frames=ds_1_num_frames,
        camera_features=DUMMY_CAMERA_FEATURES_WITH_DEPTH,
    )
    assert len(ds_0.meta.depth_keys) > 0, "ds_0 should expose at least one depth key"
    assert len(ds_1.meta.depth_keys) > 0, "ds_1 should expose at least one depth key"

    # Use the configurable size parameters to force file rotation.
    aggregate_datasets(
        repo_ids=[ds_0.repo_id, ds_1.repo_id],
        roots=[ds_0.root, ds_1.root],
        aggr_repo_id=aggr_repo_id,
        aggr_root=aggr_root,
        # Tiny file size to trigger new file instantiation
        data_files_size_in_mb=0.01,
        video_files_size_in_mb=0.1,
    )

    # Mock the revision to prevent Hub calls during dataset loading
    with (
        patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(aggr_root)
        aggr_ds = LeRobotDataset(aggr_repo_id, root=aggr_root)

    # Verify aggregation worked correctly despite file size constraints
    expected_total_episodes = ds_0_num_episodes + ds_1_num_episodes
    expected_total_frames = ds_0_num_frames + ds_1_num_frames
    assert_episode_and_frame_counts(aggr_ds, expected_total_episodes, expected_total_frames)
    assert_dataset_content_integrity(aggr_ds, ds_0, ds_1)
    assert_metadata_consistency(aggr_ds, ds_0, ds_1)
    assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)
    assert_video_frames_integrity(aggr_ds, ds_0, ds_1)
    assert_video_timestamps_within_bounds(aggr_ds)
    assert_depth_keys_preserved(aggr_ds, ds_0, ds_1)
    assert_dataset_iteration_works(aggr_ds)

    # Check that multiple files were actually created due to small size limits.
    # No ``exists()`` guard: an absent directory must fail, not skip, the check.
    data_dir = aggr_root / "data"
    video_dir = aggr_root / "videos"

    assert data_dir.is_dir(), f"Aggregated data directory is missing: {data_dir}"
    parquet_files = list(data_dir.rglob("*.parquet"))
    assert len(parquet_files) > 1, "Small file size limits should create multiple parquet files"

    assert video_dir.is_dir(), f"Aggregated videos directory is missing: {video_dir}"
    video_files = list(video_dir.rglob("*.mp4"))
    assert len(video_files) > 1, "Small file size limits should create multiple video files"
def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory):
    """Regression test for video timestamp bug when merging datasets.

    This test specifically checks that video timestamps are correctly calculated
    and accumulated when merging multiple datasets. Depth video features are
    included so depth timestamps are also covered by the regression.
    """
    # Named ``source_datasets`` so the local list does not shadow the
    # module-level ``datasets`` import used elsewhere in this file.
    source_datasets = [
        lerobot_dataset_factory(
            root=tmp_path / f"regression_{i}",
            repo_id=f"{DUMMY_REPO_ID}_regression_{i}",
            total_episodes=2,
            total_frames=100,
            camera_features=DUMMY_CAMERA_FEATURES_WITH_DEPTH,
        )
        for i in range(3)
    ]
    for i, ds in enumerate(source_datasets):
        assert len(ds.meta.depth_keys) > 0, f"Dataset {i} should expose at least one depth key"

    aggr_repo_id = f"{DUMMY_REPO_ID}_regression_aggr"
    aggr_root = tmp_path / "regression_aggr"
    aggregate_datasets(
        repo_ids=[ds.repo_id for ds in source_datasets],
        roots=[ds.root for ds in source_datasets],
        aggr_repo_id=aggr_repo_id,
        aggr_root=aggr_root,
    )

    # Prevent Hub calls while loading the merged dataset.
    with (
        patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(aggr_root)
        aggr_ds = LeRobotDataset(aggr_repo_id, root=aggr_root)

    assert_video_timestamps_within_bounds(aggr_ds)

    # Depth keys must survive the merge for the regression to cover the
    # ``is_depth_map`` decoding branch.
    assert set(aggr_ds.meta.depth_keys) == set(source_datasets[0].meta.depth_keys)
    depth_keys = set(aggr_ds.meta.depth_keys)

    for i in range(len(aggr_ds)):
        item = aggr_ds[i]
        for key in aggr_ds.meta.video_keys:
            assert key in item, f"Video key {key} missing from item {i}"
            # Depth frames are single-channel (1, H, W) after dequantization;
            # standard RGB frames keep the 3-channel layout.
            expected_channels = 1 if key in depth_keys else 3
            assert item[key].shape[0] == expected_channels, (
                f"Expected {expected_channels} channels for video key {key}, got {item[key].shape}"
            )
def assert_image_schema_preserved(aggr_ds):
    """Check that image columns keep the HuggingFace ``Image()`` feature type in parquet files.

    Guards against a bug where image columns were written with a generic
    struct schema {'bytes': Value('binary'), 'path': Value('string')} instead of
    the proper Image() feature type, causing HuggingFace Hub viewer to display
    raw dict objects instead of image thumbnails.

    Returns immediately when the dataset has no image keys.

    Raises:
        AssertionError: If no parquet file is found, an image key is missing from
            a file's schema, or its feature is not a ``datasets.Image``.
    """
    image_keys = aggr_ds.meta.image_keys
    if image_keys:
        # Every parquet file under <root>/data must carry the proper Image schema.
        shard_paths = list((aggr_ds.root / "data").rglob("*.parquet"))
        assert len(shard_paths) > 0, "No parquet files found in aggregated dataset"

        for shard_path in shard_paths:
            # Load with HuggingFace datasets to inspect the stored schema.
            shard_features = datasets.Dataset.from_parquet(str(shard_path)).features
            for image_key in image_keys:
                feature = shard_features.get(image_key)
                assert feature is not None, f"Image key '{image_key}' not found in parquet schema"
                assert isinstance(feature, datasets.Image), (
                    f"Image key '{image_key}' should have Image() feature type, "
                    f"but got {type(feature).__name__}: {feature}. "
                    "This indicates image schema was not preserved during aggregation."
                )
def assert_image_frames_integrity(aggr_ds, ds_0, ds_1):
    """Test that image frames are correctly preserved after aggregation.

    For every position of the aggregated dataset, the ``index`` field must equal
    the position, and every image feature must be ``torch.allclose`` to the image
    at the matching position of the source dataset. Returns immediately when the
    aggregated dataset has no image keys.

    Each sample is fetched once per position instead of once per checked key,
    so indexing work is not repeated for every image key.

    Args:
        aggr_ds: Aggregated dataset; indexable, with ``meta.image_keys``.
        ds_0: First source dataset (occupies the leading positions).
        ds_1: Second source dataset (follows ``ds_0``).

    Raises:
        AssertionError: On a frame index mismatch or a differing image.
    """
    image_keys = aggr_ds.meta.image_keys
    if not image_keys:
        return

    offset = 0
    for source_name, source in (("ds_0", ds_0), ("ds_1", ds_1)):
        for local_idx in range(len(source)):
            position = offset + local_idx
            aggr_item = aggr_ds[position]
            assert aggr_item["index"] == position, (
                f"Frame index at position {position} should be {position}, but got {aggr_item['index']}"
            )
            source_item = source[local_idx]
            for key in image_keys:
                assert torch.allclose(aggr_item[key], source_item[key]), (
                    f"Image frames at position {position} should be equal between aggregated and {source_name}"
                )
        offset += len(source)
def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory):
    """Aggregate image-based datasets and verify the HuggingFace Image schema is preserved.

    Specifically verifies that:
    1. Image-based datasets can be aggregated correctly
    2. The HuggingFace Image() feature type is preserved in parquet files
    3. Image data integrity is maintained across aggregation
    4. Images can be properly decoded after aggregation

    This catches the bug where to_parquet_with_hf_images() was not passing
    the features schema, causing image columns to be written as generic
    struct types instead of Image() types.
    """
    episode_counts = (2, 3)
    frame_counts = (50, 75)
    aggr_repo_id = f"{DUMMY_REPO_ID}_image_aggr"
    aggr_root = tmp_path / "image_aggr"

    # Two image-based datasets (use_videos=False) with a mix of RGB and
    # depth-marked cameras so the depth path is exercised in image mode.
    ds_0, ds_1 = (
        lerobot_dataset_factory(
            root=tmp_path / f"image_{idx}",
            repo_id=f"{DUMMY_REPO_ID}_image_{idx}",
            total_episodes=episode_counts[idx],
            total_frames=frame_counts[idx],
            use_videos=False,
            camera_features=DUMMY_CAMERA_FEATURES_WITH_DEPTH,
        )
        for idx in range(2)
    )

    labelled_sources = (("ds_0", ds_0), ("ds_1", ds_1))
    # Source datasets must expose image keys ...
    for label, source in labelled_sources:
        assert len(source.meta.image_keys) > 0, f"{label} should have image keys"
    # ... and the depth marker must have made it onto an image feature.
    for label, source in labelled_sources:
        assert len(source.meta.depth_keys) > 0, f"{label} should expose at least one depth key"

    aggregate_datasets(
        repo_ids=[ds_0.repo_id, ds_1.repo_id],
        roots=[ds_0.root, ds_1.root],
        aggr_repo_id=aggr_repo_id,
        aggr_root=aggr_root,
    )

    # Load the aggregated dataset without any Hub access.
    with (
        patch("lerobot.datasets.dataset_metadata.get_safe_version", return_value="v3.0"),
        patch("lerobot.datasets.dataset_metadata.snapshot_download", return_value=str(aggr_root)),
    ):
        aggr_ds = LeRobotDataset(aggr_repo_id, root=aggr_root)

    aggr_image_keys = aggr_ds.meta.image_keys
    assert len(aggr_image_keys) > 0, "Aggregated dataset should have image keys"
    assert aggr_image_keys == ds_0.meta.image_keys, "Image keys should match source datasets"

    # Standard aggregation assertions.
    assert_episode_and_frame_counts(aggr_ds, sum(episode_counts), sum(frame_counts))
    assert_dataset_content_integrity(aggr_ds, ds_0, ds_1)
    assert_metadata_consistency(aggr_ds, ds_0, ds_1)
    assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)

    # Image-specific assertions.
    assert_image_schema_preserved(aggr_ds)
    assert_image_frames_integrity(aggr_ds, ds_0, ds_1)
    assert_depth_keys_preserved(aggr_ds, ds_0, ds_1)

    # Images must be accessible and have the expected shape.
    first_sample = aggr_ds[0]
    for image_key in aggr_ds.meta.image_keys:
        image = first_sample[image_key]
        assert isinstance(image, torch.Tensor), f"Image {image_key} should be a tensor"
        assert image.dim() == 3, f"Image {image_key} should have 3 dimensions (C, H, W)"
        assert image.shape[0] == 3, f"Image {image_key} should have 3 channels"

    assert_dataset_iteration_works(aggr_ds)
def test_aggregate_already_merged_dataset(tmp_path, lerobot_dataset_factory):
    """Regression test: aggregate a dataset that is itself the result of a previous merge.

    Reproduces the bug where merging datasets with multiple parquet files
    (e.g., from a previous merge with file rotation) would cause FileNotFoundError
    because metadata file indices were incorrectly preserved instead of being mapped
    to their actual destination files.

    The fix adds src_to_dst tracking in aggregate_data() to correctly map source
    file indices to destination file indices.
    """
    ab_repo_id = f"{DUMMY_REPO_ID}_ab"
    ab_root = tmp_path / "ds_ab"
    abc_repo_id = f"{DUMMY_REPO_ID}_abc"
    abc_root = tmp_path / "ds_abc"

    # Step 1: create the two identically sized datasets A and B.
    ds_a, ds_b = (
        lerobot_dataset_factory(
            root=tmp_path / f"ds_{tag}",
            repo_id=f"{DUMMY_REPO_ID}_{tag}",
            total_episodes=4,
            total_frames=200,
        )
        for tag in ("a", "b")
    )

    # Step 2: merge A+B into AB with a small data file size to force multiple files.
    aggregate_datasets(
        repo_ids=[ds_a.repo_id, ds_b.repo_id],
        roots=[ds_a.root, ds_b.root],
        aggr_repo_id=ab_repo_id,
        aggr_root=ab_root,
        data_files_size_in_mb=0.01,  # Force file rotation
    )
    with (
        patch("lerobot.datasets.dataset_metadata.get_safe_version", return_value="v3.0"),
        patch("lerobot.datasets.dataset_metadata.snapshot_download", return_value=str(ab_root)),
    ):
        ds_ab = LeRobotDataset(ab_repo_id, root=ab_root)

    # AB must consist of several data files, i.e. file rotation occurred.
    ab_parquet_files = list((ab_root / "data").rglob("*.parquet"))
    assert len(ab_parquet_files) > 1, "First merge should create multiple parquet files"

    # Step 3: create dataset C.
    ds_c = lerobot_dataset_factory(
        root=tmp_path / "ds_c",
        repo_id=f"{DUMMY_REPO_ID}_c",
        total_episodes=2,
        total_frames=100,
    )

    # Step 4: merge AB+C into the final dataset - THIS IS WHERE THE BUG OCCURRED.
    aggregate_datasets(
        repo_ids=[ds_ab.repo_id, ds_c.repo_id],
        roots=[ds_ab.root, ds_c.root],
        aggr_repo_id=abc_repo_id,
        aggr_root=abc_root,
    )
    with (
        patch("lerobot.datasets.dataset_metadata.get_safe_version", return_value="v3.0"),
        patch("lerobot.datasets.dataset_metadata.snapshot_download", return_value=str(abc_root)),
    ):
        ds_abc = LeRobotDataset(abc_repo_id, root=abc_root)

    # Step 5: every data file referenced in the metadata must actually exist.
    for ep_idx in range(ds_abc.num_episodes):
        referenced_file = ds_abc.root / ds_abc.meta.get_data_file_path(ep_idx)
        assert referenced_file.exists(), (
            f"Episode {ep_idx} references non-existent file: {referenced_file}\n"
            "This indicates the src_to_dst mapping fix is not working correctly."
        )

    # Step 6: totals must add up and the dataset must be fully iterable.
    sources = (ds_a, ds_b, ds_c)
    expected_episodes = sum(source.num_episodes for source in sources)
    expected_frames = sum(source.num_frames for source in sources)
    assert ds_abc.num_episodes == expected_episodes
    assert ds_abc.num_frames == expected_frames

    # This would raise FileNotFoundError before the fix
    assert_dataset_iteration_works(ds_abc)