mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 12:40:08 +00:00
Fixes aggregation of image datasets (#2717)
* fix: use features when aggregating image based datasets * add: test asserting for data type * add: features param to writing dataset --------- Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
This commit is contained in:
committed by
GitHub
parent
66929c5935
commit
b2ff219624
@@ -19,6 +19,7 @@ import logging
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import pandas as pd
|
||||
import tqdm
|
||||
|
||||
@@ -32,6 +33,7 @@ from lerobot.datasets.utils import (
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
DEFAULT_VIDEO_PATH,
|
||||
get_file_size_in_mb,
|
||||
get_hf_features_from_features,
|
||||
get_parquet_file_size_in_mb,
|
||||
to_parquet_with_hf_images,
|
||||
update_chunk_file_indices,
|
||||
@@ -402,12 +404,21 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
||||
}
|
||||
|
||||
unique_chunk_file_ids = sorted(unique_chunk_file_ids)
|
||||
contains_images = len(dst_meta.image_keys) > 0
|
||||
|
||||
# retrieve features schema for proper image typing in parquet
|
||||
hf_features = get_hf_features_from_features(dst_meta.features) if contains_images else None
|
||||
|
||||
for src_chunk_idx, src_file_idx in unique_chunk_file_ids:
|
||||
src_path = src_meta.root / DEFAULT_DATA_PATH.format(
|
||||
chunk_index=src_chunk_idx, file_index=src_file_idx
|
||||
)
|
||||
df = pd.read_parquet(src_path)
|
||||
if contains_images:
|
||||
# Use HuggingFace datasets to read source data to preserve image format
|
||||
src_ds = datasets.Dataset.from_parquet(str(src_path))
|
||||
df = src_ds.to_pandas()
|
||||
else:
|
||||
df = pd.read_parquet(src_path)
|
||||
df = update_data_df(df, src_meta, dst_meta)
|
||||
|
||||
data_idx = append_or_create_parquet_file(
|
||||
@@ -417,8 +428,9 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
||||
data_files_size_in_mb,
|
||||
chunk_size,
|
||||
DEFAULT_DATA_PATH,
|
||||
contains_images=len(dst_meta.image_keys) > 0,
|
||||
contains_images=contains_images,
|
||||
aggr_root=dst_meta.root,
|
||||
hf_features=hf_features,
|
||||
)
|
||||
|
||||
return data_idx
|
||||
@@ -488,6 +500,7 @@ def append_or_create_parquet_file(
|
||||
default_path: str,
|
||||
contains_images: bool = False,
|
||||
aggr_root: Path = None,
|
||||
hf_features: datasets.Features | None = None,
|
||||
):
|
||||
"""Appends data to an existing parquet file or creates a new one based on size constraints.
|
||||
|
||||
@@ -503,6 +516,7 @@ def append_or_create_parquet_file(
|
||||
default_path: Format string for generating file paths.
|
||||
contains_images: Whether the data contains images requiring special handling.
|
||||
aggr_root: Root path for the aggregated dataset.
|
||||
hf_features: Optional HuggingFace Features schema for proper image typing.
|
||||
|
||||
Returns:
|
||||
dict: Updated index dictionary with current chunk and file indices.
|
||||
@@ -512,7 +526,7 @@ def append_or_create_parquet_file(
|
||||
if not dst_path.exists():
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
if contains_images:
|
||||
to_parquet_with_hf_images(df, dst_path)
|
||||
to_parquet_with_hf_images(df, dst_path, features=hf_features)
|
||||
else:
|
||||
df.to_parquet(dst_path)
|
||||
return idx
|
||||
@@ -527,12 +541,17 @@ def append_or_create_parquet_file(
|
||||
final_df = df
|
||||
target_path = new_path
|
||||
else:
|
||||
existing_df = pd.read_parquet(dst_path)
|
||||
if contains_images:
|
||||
# Use HuggingFace datasets to read existing data to preserve image format
|
||||
existing_ds = datasets.Dataset.from_parquet(str(dst_path))
|
||||
existing_df = existing_ds.to_pandas()
|
||||
else:
|
||||
existing_df = pd.read_parquet(dst_path)
|
||||
final_df = pd.concat([existing_df, df], ignore_index=True)
|
||||
target_path = dst_path
|
||||
|
||||
if contains_images:
|
||||
to_parquet_with_hf_images(final_df, target_path)
|
||||
to_parquet_with_hf_images(final_df, target_path, features=hf_features)
|
||||
else:
|
||||
final_df.to_parquet(target_path)
|
||||
|
||||
|
||||
@@ -1172,12 +1172,21 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features:
|
||||
)
|
||||
|
||||
|
||||
def to_parquet_with_hf_images(df: pandas.DataFrame, path: Path) -> None:
|
||||
def to_parquet_with_hf_images(
|
||||
df: pandas.DataFrame, path: Path, features: datasets.Features | None = None
|
||||
) -> None:
|
||||
"""This function correctly writes to parquet a panda DataFrame that contains images encoded by HF dataset.
|
||||
This way, it can be loaded by HF dataset and correctly formatted images are returned.
|
||||
|
||||
Args:
|
||||
df: DataFrame to write to parquet.
|
||||
path: Path to write the parquet file.
|
||||
features: Optional HuggingFace Features schema. If provided, ensures image columns
|
||||
are properly typed as Image() in the parquet schema.
|
||||
"""
|
||||
# TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only
|
||||
datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)
|
||||
ds = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=features)
|
||||
ds.to_parquet(path)
|
||||
|
||||
|
||||
def item_to_torch(item: dict) -> dict:
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
|
||||
from lerobot.datasets.aggregate import aggregate_datasets
|
||||
@@ -380,3 +381,147 @@ def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory):
|
||||
for key in aggr_ds.meta.video_keys:
|
||||
assert key in item, f"Video key {key} missing from item {i}"
|
||||
assert item[key].shape[0] == 3, f"Expected 3 channels for video key {key}"
|
||||
|
||||
|
||||
def assert_image_schema_preserved(aggr_ds):
|
||||
"""Test that HuggingFace Image feature schema is preserved in aggregated parquet files.
|
||||
|
||||
This verifies the fix for a bug where image columns were written with a generic
|
||||
struct schema {'bytes': Value('binary'), 'path': Value('string')} instead of
|
||||
the proper Image() feature type, causing HuggingFace Hub viewer to display
|
||||
raw dict objects instead of image thumbnails.
|
||||
"""
|
||||
image_keys = aggr_ds.meta.image_keys
|
||||
if not image_keys:
|
||||
return
|
||||
|
||||
# Check that parquet files have proper Image schema
|
||||
data_dir = aggr_ds.root / "data"
|
||||
parquet_files = list(data_dir.rglob("*.parquet"))
|
||||
assert len(parquet_files) > 0, "No parquet files found in aggregated dataset"
|
||||
|
||||
for parquet_file in parquet_files:
|
||||
# Load with HuggingFace datasets to check schema
|
||||
ds = datasets.Dataset.from_parquet(str(parquet_file))
|
||||
|
||||
for image_key in image_keys:
|
||||
feature = ds.features.get(image_key)
|
||||
assert feature is not None, f"Image key '{image_key}' not found in parquet schema"
|
||||
assert isinstance(feature, datasets.Image), (
|
||||
f"Image key '{image_key}' should have Image() feature type, "
|
||||
f"but got {type(feature).__name__}: {feature}. "
|
||||
"This indicates image schema was not preserved during aggregation."
|
||||
)
|
||||
|
||||
|
||||
def assert_image_frames_integrity(aggr_ds, ds_0, ds_1):
|
||||
"""Test that image frames are correctly preserved after aggregation."""
|
||||
image_keys = aggr_ds.meta.image_keys
|
||||
if not image_keys:
|
||||
return
|
||||
|
||||
def images_equal(img1, img2):
|
||||
return torch.allclose(img1, img2)
|
||||
|
||||
# Test the section corresponding to the first dataset (ds_0)
|
||||
for i in range(len(ds_0)):
|
||||
assert aggr_ds[i]["index"] == i, (
|
||||
f"Frame index at position {i} should be {i}, but got {aggr_ds[i]['index']}"
|
||||
)
|
||||
for key in image_keys:
|
||||
assert images_equal(aggr_ds[i][key], ds_0[i][key]), (
|
||||
f"Image frames at position {i} should be equal between aggregated and ds_0"
|
||||
)
|
||||
|
||||
# Test the section corresponding to the second dataset (ds_1)
|
||||
for i in range(len(ds_0), len(ds_0) + len(ds_1)):
|
||||
assert aggr_ds[i]["index"] == i, (
|
||||
f"Frame index at position {i} should be {i}, but got {aggr_ds[i]['index']}"
|
||||
)
|
||||
for key in image_keys:
|
||||
assert images_equal(aggr_ds[i][key], ds_1[i - len(ds_0)][key]), (
|
||||
f"Image frames at position {i} should be equal between aggregated and ds_1"
|
||||
)
|
||||
|
||||
|
||||
def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory):
|
||||
"""Test aggregation of image-based datasets preserves HuggingFace Image schema.
|
||||
|
||||
This test specifically verifies that:
|
||||
1. Image-based datasets can be aggregated correctly
|
||||
2. The HuggingFace Image() feature type is preserved in parquet files
|
||||
3. Image data integrity is maintained across aggregation
|
||||
4. Images can be properly decoded after aggregation
|
||||
|
||||
This catches the bug where to_parquet_with_hf_images() was not passing
|
||||
the features schema, causing image columns to be written as generic
|
||||
struct types instead of Image() types.
|
||||
"""
|
||||
ds_0_num_frames = 50
|
||||
ds_1_num_frames = 75
|
||||
ds_0_num_episodes = 2
|
||||
ds_1_num_episodes = 3
|
||||
|
||||
# Create two image-based datasets (use_videos=False)
|
||||
ds_0 = lerobot_dataset_factory(
|
||||
root=tmp_path / "image_0",
|
||||
repo_id=f"{DUMMY_REPO_ID}_image_0",
|
||||
total_episodes=ds_0_num_episodes,
|
||||
total_frames=ds_0_num_frames,
|
||||
use_videos=False, # Image-based dataset
|
||||
)
|
||||
ds_1 = lerobot_dataset_factory(
|
||||
root=tmp_path / "image_1",
|
||||
repo_id=f"{DUMMY_REPO_ID}_image_1",
|
||||
total_episodes=ds_1_num_episodes,
|
||||
total_frames=ds_1_num_frames,
|
||||
use_videos=False, # Image-based dataset
|
||||
)
|
||||
|
||||
# Verify source datasets have image keys
|
||||
assert len(ds_0.meta.image_keys) > 0, "ds_0 should have image keys"
|
||||
assert len(ds_1.meta.image_keys) > 0, "ds_1 should have image keys"
|
||||
|
||||
# Aggregate the datasets
|
||||
aggregate_datasets(
|
||||
repo_ids=[ds_0.repo_id, ds_1.repo_id],
|
||||
roots=[ds_0.root, ds_1.root],
|
||||
aggr_repo_id=f"{DUMMY_REPO_ID}_image_aggr",
|
||||
aggr_root=tmp_path / "image_aggr",
|
||||
)
|
||||
|
||||
# Load the aggregated dataset
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(tmp_path / "image_aggr")
|
||||
aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_image_aggr", root=tmp_path / "image_aggr")
|
||||
|
||||
# Verify aggregated dataset has image keys
|
||||
assert len(aggr_ds.meta.image_keys) > 0, "Aggregated dataset should have image keys"
|
||||
assert aggr_ds.meta.image_keys == ds_0.meta.image_keys, "Image keys should match source datasets"
|
||||
|
||||
# Run standard aggregation assertions
|
||||
expected_total_episodes = ds_0_num_episodes + ds_1_num_episodes
|
||||
expected_total_frames = ds_0_num_frames + ds_1_num_frames
|
||||
|
||||
assert_episode_and_frame_counts(aggr_ds, expected_total_episodes, expected_total_frames)
|
||||
assert_dataset_content_integrity(aggr_ds, ds_0, ds_1)
|
||||
assert_metadata_consistency(aggr_ds, ds_0, ds_1)
|
||||
assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)
|
||||
|
||||
# Image-specific assertions
|
||||
assert_image_schema_preserved(aggr_ds)
|
||||
assert_image_frames_integrity(aggr_ds, ds_0, ds_1)
|
||||
|
||||
# Verify images can be accessed and have correct shape
|
||||
sample_item = aggr_ds[0]
|
||||
for image_key in aggr_ds.meta.image_keys:
|
||||
img = sample_item[image_key]
|
||||
assert isinstance(img, torch.Tensor), f"Image {image_key} should be a tensor"
|
||||
assert img.dim() == 3, f"Image {image_key} should have 3 dimensions (C, H, W)"
|
||||
assert img.shape[0] == 3, f"Image {image_key} should have 3 channels"
|
||||
|
||||
assert_dataset_iteration_works(aggr_ds)
|
||||
|
||||
Reference in New Issue
Block a user