mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 11:09:59 +00:00
Fixes aggregation of image datasets (#2717)
* fix: use features when aggregating image based datasets * add: test asserting for data type * add: features param to writing dataset --------- Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
This commit is contained in:
committed by
GitHub
parent
66929c5935
commit
b2ff219624
@@ -19,6 +19,7 @@ import logging
|
|||||||
import shutil
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import datasets
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import tqdm
|
import tqdm
|
||||||
|
|
||||||
@@ -32,6 +33,7 @@ from lerobot.datasets.utils import (
|
|||||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||||
DEFAULT_VIDEO_PATH,
|
DEFAULT_VIDEO_PATH,
|
||||||
get_file_size_in_mb,
|
get_file_size_in_mb,
|
||||||
|
get_hf_features_from_features,
|
||||||
get_parquet_file_size_in_mb,
|
get_parquet_file_size_in_mb,
|
||||||
to_parquet_with_hf_images,
|
to_parquet_with_hf_images,
|
||||||
update_chunk_file_indices,
|
update_chunk_file_indices,
|
||||||
@@ -402,12 +404,21 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
|||||||
}
|
}
|
||||||
|
|
||||||
unique_chunk_file_ids = sorted(unique_chunk_file_ids)
|
unique_chunk_file_ids = sorted(unique_chunk_file_ids)
|
||||||
|
contains_images = len(dst_meta.image_keys) > 0
|
||||||
|
|
||||||
|
# retrieve features schema for proper image typing in parquet
|
||||||
|
hf_features = get_hf_features_from_features(dst_meta.features) if contains_images else None
|
||||||
|
|
||||||
for src_chunk_idx, src_file_idx in unique_chunk_file_ids:
|
for src_chunk_idx, src_file_idx in unique_chunk_file_ids:
|
||||||
src_path = src_meta.root / DEFAULT_DATA_PATH.format(
|
src_path = src_meta.root / DEFAULT_DATA_PATH.format(
|
||||||
chunk_index=src_chunk_idx, file_index=src_file_idx
|
chunk_index=src_chunk_idx, file_index=src_file_idx
|
||||||
)
|
)
|
||||||
df = pd.read_parquet(src_path)
|
if contains_images:
|
||||||
|
# Use HuggingFace datasets to read source data to preserve image format
|
||||||
|
src_ds = datasets.Dataset.from_parquet(str(src_path))
|
||||||
|
df = src_ds.to_pandas()
|
||||||
|
else:
|
||||||
|
df = pd.read_parquet(src_path)
|
||||||
df = update_data_df(df, src_meta, dst_meta)
|
df = update_data_df(df, src_meta, dst_meta)
|
||||||
|
|
||||||
data_idx = append_or_create_parquet_file(
|
data_idx = append_or_create_parquet_file(
|
||||||
@@ -417,8 +428,9 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
|||||||
data_files_size_in_mb,
|
data_files_size_in_mb,
|
||||||
chunk_size,
|
chunk_size,
|
||||||
DEFAULT_DATA_PATH,
|
DEFAULT_DATA_PATH,
|
||||||
contains_images=len(dst_meta.image_keys) > 0,
|
contains_images=contains_images,
|
||||||
aggr_root=dst_meta.root,
|
aggr_root=dst_meta.root,
|
||||||
|
hf_features=hf_features,
|
||||||
)
|
)
|
||||||
|
|
||||||
return data_idx
|
return data_idx
|
||||||
@@ -488,6 +500,7 @@ def append_or_create_parquet_file(
|
|||||||
default_path: str,
|
default_path: str,
|
||||||
contains_images: bool = False,
|
contains_images: bool = False,
|
||||||
aggr_root: Path = None,
|
aggr_root: Path = None,
|
||||||
|
hf_features: datasets.Features | None = None,
|
||||||
):
|
):
|
||||||
"""Appends data to an existing parquet file or creates a new one based on size constraints.
|
"""Appends data to an existing parquet file or creates a new one based on size constraints.
|
||||||
|
|
||||||
@@ -503,6 +516,7 @@ def append_or_create_parquet_file(
|
|||||||
default_path: Format string for generating file paths.
|
default_path: Format string for generating file paths.
|
||||||
contains_images: Whether the data contains images requiring special handling.
|
contains_images: Whether the data contains images requiring special handling.
|
||||||
aggr_root: Root path for the aggregated dataset.
|
aggr_root: Root path for the aggregated dataset.
|
||||||
|
hf_features: Optional HuggingFace Features schema for proper image typing.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: Updated index dictionary with current chunk and file indices.
|
dict: Updated index dictionary with current chunk and file indices.
|
||||||
@@ -512,7 +526,7 @@ def append_or_create_parquet_file(
|
|||||||
if not dst_path.exists():
|
if not dst_path.exists():
|
||||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
if contains_images:
|
if contains_images:
|
||||||
to_parquet_with_hf_images(df, dst_path)
|
to_parquet_with_hf_images(df, dst_path, features=hf_features)
|
||||||
else:
|
else:
|
||||||
df.to_parquet(dst_path)
|
df.to_parquet(dst_path)
|
||||||
return idx
|
return idx
|
||||||
@@ -527,12 +541,17 @@ def append_or_create_parquet_file(
|
|||||||
final_df = df
|
final_df = df
|
||||||
target_path = new_path
|
target_path = new_path
|
||||||
else:
|
else:
|
||||||
existing_df = pd.read_parquet(dst_path)
|
if contains_images:
|
||||||
|
# Use HuggingFace datasets to read existing data to preserve image format
|
||||||
|
existing_ds = datasets.Dataset.from_parquet(str(dst_path))
|
||||||
|
existing_df = existing_ds.to_pandas()
|
||||||
|
else:
|
||||||
|
existing_df = pd.read_parquet(dst_path)
|
||||||
final_df = pd.concat([existing_df, df], ignore_index=True)
|
final_df = pd.concat([existing_df, df], ignore_index=True)
|
||||||
target_path = dst_path
|
target_path = dst_path
|
||||||
|
|
||||||
if contains_images:
|
if contains_images:
|
||||||
to_parquet_with_hf_images(final_df, target_path)
|
to_parquet_with_hf_images(final_df, target_path, features=hf_features)
|
||||||
else:
|
else:
|
||||||
final_df.to_parquet(target_path)
|
final_df.to_parquet(target_path)
|
||||||
|
|
||||||
|
|||||||
@@ -1172,12 +1172,21 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def to_parquet_with_hf_images(df: pandas.DataFrame, path: Path) -> None:
|
def to_parquet_with_hf_images(
|
||||||
|
df: pandas.DataFrame, path: Path, features: datasets.Features | None = None
|
||||||
|
) -> None:
|
||||||
"""This function correctly writes to parquet a panda DataFrame that contains images encoded by HF dataset.
|
"""This function correctly writes to parquet a panda DataFrame that contains images encoded by HF dataset.
|
||||||
This way, it can be loaded by HF dataset and correctly formatted images are returned.
|
This way, it can be loaded by HF dataset and correctly formatted images are returned.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
df: DataFrame to write to parquet.
|
||||||
|
path: Path to write the parquet file.
|
||||||
|
features: Optional HuggingFace Features schema. If provided, ensures image columns
|
||||||
|
are properly typed as Image() in the parquet schema.
|
||||||
"""
|
"""
|
||||||
# TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only
|
# TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only
|
||||||
datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)
|
ds = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=features)
|
||||||
|
ds.to_parquet(path)
|
||||||
|
|
||||||
|
|
||||||
def item_to_torch(item: dict) -> dict:
|
def item_to_torch(item: dict) -> dict:
|
||||||
|
|||||||
@@ -16,6 +16,7 @@
|
|||||||
|
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import datasets
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.datasets.aggregate import aggregate_datasets
|
from lerobot.datasets.aggregate import aggregate_datasets
|
||||||
@@ -380,3 +381,147 @@ def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory):
|
|||||||
for key in aggr_ds.meta.video_keys:
|
for key in aggr_ds.meta.video_keys:
|
||||||
assert key in item, f"Video key {key} missing from item {i}"
|
assert key in item, f"Video key {key} missing from item {i}"
|
||||||
assert item[key].shape[0] == 3, f"Expected 3 channels for video key {key}"
|
assert item[key].shape[0] == 3, f"Expected 3 channels for video key {key}"
|
||||||
|
|
||||||
|
|
||||||
|
def assert_image_schema_preserved(aggr_ds):
|
||||||
|
"""Test that HuggingFace Image feature schema is preserved in aggregated parquet files.
|
||||||
|
|
||||||
|
This verifies the fix for a bug where image columns were written with a generic
|
||||||
|
struct schema {'bytes': Value('binary'), 'path': Value('string')} instead of
|
||||||
|
the proper Image() feature type, causing HuggingFace Hub viewer to display
|
||||||
|
raw dict objects instead of image thumbnails.
|
||||||
|
"""
|
||||||
|
image_keys = aggr_ds.meta.image_keys
|
||||||
|
if not image_keys:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Check that parquet files have proper Image schema
|
||||||
|
data_dir = aggr_ds.root / "data"
|
||||||
|
parquet_files = list(data_dir.rglob("*.parquet"))
|
||||||
|
assert len(parquet_files) > 0, "No parquet files found in aggregated dataset"
|
||||||
|
|
||||||
|
for parquet_file in parquet_files:
|
||||||
|
# Load with HuggingFace datasets to check schema
|
||||||
|
ds = datasets.Dataset.from_parquet(str(parquet_file))
|
||||||
|
|
||||||
|
for image_key in image_keys:
|
||||||
|
feature = ds.features.get(image_key)
|
||||||
|
assert feature is not None, f"Image key '{image_key}' not found in parquet schema"
|
||||||
|
assert isinstance(feature, datasets.Image), (
|
||||||
|
f"Image key '{image_key}' should have Image() feature type, "
|
||||||
|
f"but got {type(feature).__name__}: {feature}. "
|
||||||
|
"This indicates image schema was not preserved during aggregation."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def assert_image_frames_integrity(aggr_ds, ds_0, ds_1):
|
||||||
|
"""Test that image frames are correctly preserved after aggregation."""
|
||||||
|
image_keys = aggr_ds.meta.image_keys
|
||||||
|
if not image_keys:
|
||||||
|
return
|
||||||
|
|
||||||
|
def images_equal(img1, img2):
|
||||||
|
return torch.allclose(img1, img2)
|
||||||
|
|
||||||
|
# Test the section corresponding to the first dataset (ds_0)
|
||||||
|
for i in range(len(ds_0)):
|
||||||
|
assert aggr_ds[i]["index"] == i, (
|
||||||
|
f"Frame index at position {i} should be {i}, but got {aggr_ds[i]['index']}"
|
||||||
|
)
|
||||||
|
for key in image_keys:
|
||||||
|
assert images_equal(aggr_ds[i][key], ds_0[i][key]), (
|
||||||
|
f"Image frames at position {i} should be equal between aggregated and ds_0"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test the section corresponding to the second dataset (ds_1)
|
||||||
|
for i in range(len(ds_0), len(ds_0) + len(ds_1)):
|
||||||
|
assert aggr_ds[i]["index"] == i, (
|
||||||
|
f"Frame index at position {i} should be {i}, but got {aggr_ds[i]['index']}"
|
||||||
|
)
|
||||||
|
for key in image_keys:
|
||||||
|
assert images_equal(aggr_ds[i][key], ds_1[i - len(ds_0)][key]), (
|
||||||
|
f"Image frames at position {i} should be equal between aggregated and ds_1"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory):
|
||||||
|
"""Test aggregation of image-based datasets preserves HuggingFace Image schema.
|
||||||
|
|
||||||
|
This test specifically verifies that:
|
||||||
|
1. Image-based datasets can be aggregated correctly
|
||||||
|
2. The HuggingFace Image() feature type is preserved in parquet files
|
||||||
|
3. Image data integrity is maintained across aggregation
|
||||||
|
4. Images can be properly decoded after aggregation
|
||||||
|
|
||||||
|
This catches the bug where to_parquet_with_hf_images() was not passing
|
||||||
|
the features schema, causing image columns to be written as generic
|
||||||
|
struct types instead of Image() types.
|
||||||
|
"""
|
||||||
|
ds_0_num_frames = 50
|
||||||
|
ds_1_num_frames = 75
|
||||||
|
ds_0_num_episodes = 2
|
||||||
|
ds_1_num_episodes = 3
|
||||||
|
|
||||||
|
# Create two image-based datasets (use_videos=False)
|
||||||
|
ds_0 = lerobot_dataset_factory(
|
||||||
|
root=tmp_path / "image_0",
|
||||||
|
repo_id=f"{DUMMY_REPO_ID}_image_0",
|
||||||
|
total_episodes=ds_0_num_episodes,
|
||||||
|
total_frames=ds_0_num_frames,
|
||||||
|
use_videos=False, # Image-based dataset
|
||||||
|
)
|
||||||
|
ds_1 = lerobot_dataset_factory(
|
||||||
|
root=tmp_path / "image_1",
|
||||||
|
repo_id=f"{DUMMY_REPO_ID}_image_1",
|
||||||
|
total_episodes=ds_1_num_episodes,
|
||||||
|
total_frames=ds_1_num_frames,
|
||||||
|
use_videos=False, # Image-based dataset
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify source datasets have image keys
|
||||||
|
assert len(ds_0.meta.image_keys) > 0, "ds_0 should have image keys"
|
||||||
|
assert len(ds_1.meta.image_keys) > 0, "ds_1 should have image keys"
|
||||||
|
|
||||||
|
# Aggregate the datasets
|
||||||
|
aggregate_datasets(
|
||||||
|
repo_ids=[ds_0.repo_id, ds_1.repo_id],
|
||||||
|
roots=[ds_0.root, ds_1.root],
|
||||||
|
aggr_repo_id=f"{DUMMY_REPO_ID}_image_aggr",
|
||||||
|
aggr_root=tmp_path / "image_aggr",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load the aggregated dataset
|
||||||
|
with (
|
||||||
|
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||||
|
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||||
|
):
|
||||||
|
mock_get_safe_version.return_value = "v3.0"
|
||||||
|
mock_snapshot_download.return_value = str(tmp_path / "image_aggr")
|
||||||
|
aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_image_aggr", root=tmp_path / "image_aggr")
|
||||||
|
|
||||||
|
# Verify aggregated dataset has image keys
|
||||||
|
assert len(aggr_ds.meta.image_keys) > 0, "Aggregated dataset should have image keys"
|
||||||
|
assert aggr_ds.meta.image_keys == ds_0.meta.image_keys, "Image keys should match source datasets"
|
||||||
|
|
||||||
|
# Run standard aggregation assertions
|
||||||
|
expected_total_episodes = ds_0_num_episodes + ds_1_num_episodes
|
||||||
|
expected_total_frames = ds_0_num_frames + ds_1_num_frames
|
||||||
|
|
||||||
|
assert_episode_and_frame_counts(aggr_ds, expected_total_episodes, expected_total_frames)
|
||||||
|
assert_dataset_content_integrity(aggr_ds, ds_0, ds_1)
|
||||||
|
assert_metadata_consistency(aggr_ds, ds_0, ds_1)
|
||||||
|
assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)
|
||||||
|
|
||||||
|
# Image-specific assertions
|
||||||
|
assert_image_schema_preserved(aggr_ds)
|
||||||
|
assert_image_frames_integrity(aggr_ds, ds_0, ds_1)
|
||||||
|
|
||||||
|
# Verify images can be accessed and have correct shape
|
||||||
|
sample_item = aggr_ds[0]
|
||||||
|
for image_key in aggr_ds.meta.image_keys:
|
||||||
|
img = sample_item[image_key]
|
||||||
|
assert isinstance(img, torch.Tensor), f"Image {image_key} should be a tensor"
|
||||||
|
assert img.dim() == 3, f"Image {image_key} should have 3 dimensions (C, H, W)"
|
||||||
|
assert img.shape[0] == 3, f"Image {image_key} should have 3 channels"
|
||||||
|
|
||||||
|
assert_dataset_iteration_works(aggr_ds)
|
||||||
|
|||||||
Reference in New Issue
Block a user