Mirror of https://github.com/huggingface/lerobot.git, synced 2026-05-15 16:49:55 +00:00
Fixes aggregation of image datasets (#2717)
* fix: use features when aggregating image-based datasets
* add: test asserting the data type
* add: features param to dataset writing

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
committed by GitHub
parent 66929c5935
commit b2ff219624
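The gist of the fix, sketched below: when aggregation rebuilds a HuggingFace dataset from raw rows and writes it to parquet, the source features must be passed explicitly, otherwise schema inference degrades Image() columns to a generic struct. This is an illustrative sketch, not the code from the patch; write_data_shard is a hypothetical stand-in for the repo's parquet-writing helper, while Dataset.from_dict, to_parquet, and the features argument are real datasets APIs.

import datasets


def write_data_shard(rows: dict, features: datasets.Features, path: str) -> None:
    # Hypothetical helper, for illustration only.
    # Buggy variant: schema inference turns an Image() column into a
    # plain struct {'bytes': binary, 'path': string}:
    #     ds = datasets.Dataset.from_dict(rows)
    # Fixed variant: passing the source features keeps Image() columns typed.
    ds = datasets.Dataset.from_dict(rows, features=features)
    ds.to_parquet(path)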
@@ -16,6 +16,7 @@
from unittest.mock import patch

import datasets
import torch

from lerobot.datasets.aggregate import aggregate_datasets
@@ -380,3 +381,147 @@ def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory):
        for key in aggr_ds.meta.video_keys:
            assert key in item, f"Video key {key} missing from item {i}"
            assert item[key].shape[0] == 3, f"Expected 3 channels for video key {key}"


def assert_image_schema_preserved(aggr_ds):
    """Test that the HuggingFace Image feature schema is preserved in aggregated parquet files.

    This verifies the fix for a bug where image columns were written with a generic
    struct schema {'bytes': Value('binary'), 'path': Value('string')} instead of
    the proper Image() feature type, causing the HuggingFace Hub viewer to display
    raw dict objects instead of image thumbnails.
    """
    image_keys = aggr_ds.meta.image_keys
    if not image_keys:
        return

    # Check that parquet files have the proper Image schema
    data_dir = aggr_ds.root / "data"
    parquet_files = list(data_dir.rglob("*.parquet"))
    assert len(parquet_files) > 0, "No parquet files found in aggregated dataset"

    for parquet_file in parquet_files:
        # Load with HuggingFace datasets to check the schema
        ds = datasets.Dataset.from_parquet(str(parquet_file))

        for image_key in image_keys:
            feature = ds.features.get(image_key)
            assert feature is not None, f"Image key '{image_key}' not found in parquet schema"
            assert isinstance(feature, datasets.Image), (
                f"Image key '{image_key}' should have Image() feature type, "
                f"but got {type(feature).__name__}: {feature}. "
                "This indicates image schema was not preserved during aggregation."
            )


def assert_image_frames_integrity(aggr_ds, ds_0, ds_1):
    """Test that image frames are correctly preserved after aggregation."""
    image_keys = aggr_ds.meta.image_keys
    if not image_keys:
        return

    def images_equal(img1, img2):
        return torch.allclose(img1, img2)

    # Test the section corresponding to the first dataset (ds_0)
    for i in range(len(ds_0)):
        assert aggr_ds[i]["index"] == i, (
            f"Frame index at position {i} should be {i}, but got {aggr_ds[i]['index']}"
        )
        for key in image_keys:
            assert images_equal(aggr_ds[i][key], ds_0[i][key]), (
                f"Image frames at position {i} should be equal between aggregated and ds_0"
            )

    # Test the section corresponding to the second dataset (ds_1)
    for i in range(len(ds_0), len(ds_0) + len(ds_1)):
        assert aggr_ds[i]["index"] == i, (
            f"Frame index at position {i} should be {i}, but got {aggr_ds[i]['index']}"
        )
        for key in image_keys:
            assert images_equal(aggr_ds[i][key], ds_1[i - len(ds_0)][key]), (
                f"Image frames at position {i} should be equal between aggregated and ds_1"
            )


def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory):
    """Test that aggregation of image-based datasets preserves the HuggingFace Image schema.

    This test specifically verifies that:
    1. Image-based datasets can be aggregated correctly
    2. The HuggingFace Image() feature type is preserved in parquet files
    3. Image data integrity is maintained across aggregation
    4. Images can be properly decoded after aggregation

    This catches the bug where to_parquet_with_hf_images() was not passing
    the features schema, causing image columns to be written as generic
    struct types instead of Image() types.
    """
    ds_0_num_frames = 50
    ds_1_num_frames = 75
    ds_0_num_episodes = 2
    ds_1_num_episodes = 3

    # Create two image-based datasets (use_videos=False)
    ds_0 = lerobot_dataset_factory(
        root=tmp_path / "image_0",
        repo_id=f"{DUMMY_REPO_ID}_image_0",
        total_episodes=ds_0_num_episodes,
        total_frames=ds_0_num_frames,
        use_videos=False,  # Image-based dataset
    )
    ds_1 = lerobot_dataset_factory(
        root=tmp_path / "image_1",
        repo_id=f"{DUMMY_REPO_ID}_image_1",
        total_episodes=ds_1_num_episodes,
        total_frames=ds_1_num_frames,
        use_videos=False,  # Image-based dataset
    )

    # Verify source datasets have image keys
    assert len(ds_0.meta.image_keys) > 0, "ds_0 should have image keys"
    assert len(ds_1.meta.image_keys) > 0, "ds_1 should have image keys"

    # Aggregate the datasets
    aggregate_datasets(
        repo_ids=[ds_0.repo_id, ds_1.repo_id],
        roots=[ds_0.root, ds_1.root],
        aggr_repo_id=f"{DUMMY_REPO_ID}_image_aggr",
        aggr_root=tmp_path / "image_aggr",
    )

    # Load the aggregated dataset
    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(tmp_path / "image_aggr")
        aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_image_aggr", root=tmp_path / "image_aggr")

    # Verify aggregated dataset has image keys
    assert len(aggr_ds.meta.image_keys) > 0, "Aggregated dataset should have image keys"
    assert aggr_ds.meta.image_keys == ds_0.meta.image_keys, "Image keys should match source datasets"

    # Run standard aggregation assertions
    expected_total_episodes = ds_0_num_episodes + ds_1_num_episodes
    expected_total_frames = ds_0_num_frames + ds_1_num_frames

    assert_episode_and_frame_counts(aggr_ds, expected_total_episodes, expected_total_frames)
    assert_dataset_content_integrity(aggr_ds, ds_0, ds_1)
    assert_metadata_consistency(aggr_ds, ds_0, ds_1)
    assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)

    # Image-specific assertions
    assert_image_schema_preserved(aggr_ds)
    assert_image_frames_integrity(aggr_ds, ds_0, ds_1)

    # Verify images can be accessed and have correct shape
    sample_item = aggr_ds[0]
    for image_key in aggr_ds.meta.image_keys:
        img = sample_item[image_key]
        assert isinstance(img, torch.Tensor), f"Image {image_key} should be a tensor"
        assert img.dim() == 3, f"Image {image_key} should have 3 dimensions (C, H, W)"
        assert img.shape[0] == 3, f"Image {image_key} should have 3 channels"

    assert_dataset_iteration_works(aggr_ds)
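For a quick manual check outside the test suite, a written shard's schema can be inspected directly. A minimal sketch, assuming the datasets library is installed; the path below is a placeholder, not a file this commit creates.

import datasets

# Placeholder path; point this at a parquet shard written by aggregate_datasets.
ds = datasets.Dataset.from_parquet("image_aggr/data/chunk-000/file-000.parquet")

# Before the fix, image columns printed as a generic struct type;
# after the fix, each image key prints as Image(...).
print(ds.features)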