diff --git a/src/lerobot/datasets/aggregate.py b/src/lerobot/datasets/aggregate.py
index 455caf0fe..94ffe602e 100644
--- a/src/lerobot/datasets/aggregate.py
+++ b/src/lerobot/datasets/aggregate.py
@@ -19,6 +19,7 @@ import logging
 import shutil
 from pathlib import Path
 
+import datasets
 import pandas as pd
 import tqdm
 
@@ -32,6 +33,7 @@ from lerobot.datasets.utils import (
     DEFAULT_VIDEO_FILE_SIZE_IN_MB,
     DEFAULT_VIDEO_PATH,
     get_file_size_in_mb,
+    get_hf_features_from_features,
     get_parquet_file_size_in_mb,
     to_parquet_with_hf_images,
     update_chunk_file_indices,
@@ -402,12 +404,21 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
     }
     unique_chunk_file_ids = sorted(unique_chunk_file_ids)
 
+    contains_images = len(dst_meta.image_keys) > 0
+
+    # retrieve features schema for proper image typing in parquet
+    hf_features = get_hf_features_from_features(dst_meta.features) if contains_images else None
     for src_chunk_idx, src_file_idx in unique_chunk_file_ids:
         src_path = src_meta.root / DEFAULT_DATA_PATH.format(
             chunk_index=src_chunk_idx, file_index=src_file_idx
         )
 
-        df = pd.read_parquet(src_path)
+        if contains_images:
+            # Use HuggingFace datasets to read source data to preserve image format
+            src_ds = datasets.Dataset.from_parquet(str(src_path))
+            df = src_ds.to_pandas()
+        else:
+            df = pd.read_parquet(src_path)
         df = update_data_df(df, src_meta, dst_meta)
 
         data_idx = append_or_create_parquet_file(
@@ -417,8 +428,9 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
             data_files_size_in_mb,
             chunk_size,
             DEFAULT_DATA_PATH,
-            contains_images=len(dst_meta.image_keys) > 0,
+            contains_images=contains_images,
             aggr_root=dst_meta.root,
+            hf_features=hf_features,
         )
 
     return data_idx
@@ -488,6 +500,7 @@ def append_or_create_parquet_file(
     default_path: str,
     contains_images: bool = False,
     aggr_root: Path = None,
+    hf_features: datasets.Features | None = None,
 ):
     """Appends data to an existing parquet file or creates a new one based on size constraints.
 
@@ -503,6 +516,7 @@ def append_or_create_parquet_file(
         default_path: Format string for generating file paths.
        contains_images: Whether the data contains images requiring special handling.
         aggr_root: Root path for the aggregated dataset.
+        hf_features: Optional HuggingFace Features schema for proper image typing.
 
     Returns:
         dict: Updated index dictionary with current chunk and file indices.
@@ -512,7 +526,7 @@ def append_or_create_parquet_file(
     if not dst_path.exists():
         dst_path.parent.mkdir(parents=True, exist_ok=True)
         if contains_images:
-            to_parquet_with_hf_images(df, dst_path)
+            to_parquet_with_hf_images(df, dst_path, features=hf_features)
         else:
             df.to_parquet(dst_path)
         return idx
@@ -527,12 +541,17 @@ def append_or_create_parquet_file(
         final_df = df
         target_path = new_path
     else:
-        existing_df = pd.read_parquet(dst_path)
+        if contains_images:
+            # Use HuggingFace datasets to read existing data to preserve image format
+            existing_ds = datasets.Dataset.from_parquet(str(dst_path))
+            existing_df = existing_ds.to_pandas()
+        else:
+            existing_df = pd.read_parquet(dst_path)
         final_df = pd.concat([existing_df, df], ignore_index=True)
         target_path = dst_path
 
     if contains_images:
-        to_parquet_with_hf_images(final_df, target_path)
+        to_parquet_with_hf_images(final_df, target_path, features=hf_features)
     else:
         final_df.to_parquet(target_path)
 
diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py
index 234736a75..ed678af6e 100644
--- a/src/lerobot/datasets/utils.py
+++ b/src/lerobot/datasets/utils.py
@@ -1172,12 +1172,21 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features:
     )
 
 
-def to_parquet_with_hf_images(df: pandas.DataFrame, path: Path) -> None:
+def to_parquet_with_hf_images(
+    df: pandas.DataFrame, path: Path, features: datasets.Features | None = None
+) -> None:
     """This function correctly writes to parquet a panda DataFrame that contains images encoded by HF dataset.
     This way, it can be loaded by HF dataset and correctly formatted images are returned.
+
+    Args:
+        df: DataFrame to write to parquet.
+        path: Path to write the parquet file.
+        features: Optional HuggingFace Features schema. If provided, ensures image columns
+            are properly typed as Image() in the parquet schema.
     """
     # TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only
-    datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)
+    ds = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=features)
+    ds.to_parquet(path)
 
 
 def item_to_torch(item: dict) -> dict:
diff --git a/tests/datasets/test_aggregate.py b/tests/datasets/test_aggregate.py
index b710a3a4b..031c29d60 100644
--- a/tests/datasets/test_aggregate.py
+++ b/tests/datasets/test_aggregate.py
@@ -16,6 +16,7 @@
 
 from unittest.mock import patch
 
+import datasets
 import torch
 
 from lerobot.datasets.aggregate import aggregate_datasets
@@ -380,3 +381,147 @@ def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory):
     for key in aggr_ds.meta.video_keys:
         assert key in item, f"Video key {key} missing from item {i}"
         assert item[key].shape[0] == 3, f"Expected 3 channels for video key {key}"
+
+
+def assert_image_schema_preserved(aggr_ds):
+    """Test that HuggingFace Image feature schema is preserved in aggregated parquet files.
+
+    This verifies the fix for a bug where image columns were written with a generic
+    struct schema {'bytes': Value('binary'), 'path': Value('string')} instead of
+    the proper Image() feature type, causing HuggingFace Hub viewer to display
+    raw dict objects instead of image thumbnails.
+    """
+    image_keys = aggr_ds.meta.image_keys
+    if not image_keys:
+        return
+
+    # Check that parquet files have proper Image schema
+    data_dir = aggr_ds.root / "data"
+    parquet_files = list(data_dir.rglob("*.parquet"))
+    assert len(parquet_files) > 0, "No parquet files found in aggregated dataset"
+
+    for parquet_file in parquet_files:
+        # Load with HuggingFace datasets to check schema
+        ds = datasets.Dataset.from_parquet(str(parquet_file))
+
+        for image_key in image_keys:
+            feature = ds.features.get(image_key)
+            assert feature is not None, f"Image key '{image_key}' not found in parquet schema"
+            assert isinstance(feature, datasets.Image), (
+                f"Image key '{image_key}' should have Image() feature type, "
+                f"but got {type(feature).__name__}: {feature}. "
+                "This indicates image schema was not preserved during aggregation."
+            )
+
+
+def assert_image_frames_integrity(aggr_ds, ds_0, ds_1):
+    """Test that image frames are correctly preserved after aggregation."""
+    image_keys = aggr_ds.meta.image_keys
+    if not image_keys:
+        return
+
+    def images_equal(img1, img2):
+        return torch.allclose(img1, img2)
+
+    # Test the section corresponding to the first dataset (ds_0)
+    for i in range(len(ds_0)):
+        assert aggr_ds[i]["index"] == i, (
+            f"Frame index at position {i} should be {i}, but got {aggr_ds[i]['index']}"
+        )
+        for key in image_keys:
+            assert images_equal(aggr_ds[i][key], ds_0[i][key]), (
+                f"Image frames at position {i} should be equal between aggregated and ds_0"
+            )
+
+    # Test the section corresponding to the second dataset (ds_1)
+    for i in range(len(ds_0), len(ds_0) + len(ds_1)):
+        assert aggr_ds[i]["index"] == i, (
+            f"Frame index at position {i} should be {i}, but got {aggr_ds[i]['index']}"
+        )
+        for key in image_keys:
+            assert images_equal(aggr_ds[i][key], ds_1[i - len(ds_0)][key]), (
+                f"Image frames at position {i} should be equal between aggregated and ds_1"
+            )
+
+
+def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory):
+    """Test aggregation of image-based datasets preserves HuggingFace Image schema.
+
+    This test specifically verifies that:
+    1. Image-based datasets can be aggregated correctly
+    2. The HuggingFace Image() feature type is preserved in parquet files
+    3. Image data integrity is maintained across aggregation
+    4. Images can be properly decoded after aggregation
+
+    This catches the bug where to_parquet_with_hf_images() was not passing
+    the features schema, causing image columns to be written as generic
+    struct types instead of Image() types.
+    """
+    ds_0_num_frames = 50
+    ds_1_num_frames = 75
+    ds_0_num_episodes = 2
+    ds_1_num_episodes = 3
+
+    # Create two image-based datasets (use_videos=False)
+    ds_0 = lerobot_dataset_factory(
+        root=tmp_path / "image_0",
+        repo_id=f"{DUMMY_REPO_ID}_image_0",
+        total_episodes=ds_0_num_episodes,
+        total_frames=ds_0_num_frames,
+        use_videos=False,  # Image-based dataset
+    )
+    ds_1 = lerobot_dataset_factory(
+        root=tmp_path / "image_1",
+        repo_id=f"{DUMMY_REPO_ID}_image_1",
+        total_episodes=ds_1_num_episodes,
+        total_frames=ds_1_num_frames,
+        use_videos=False,  # Image-based dataset
+    )
+
+    # Verify source datasets have image keys
+    assert len(ds_0.meta.image_keys) > 0, "ds_0 should have image keys"
+    assert len(ds_1.meta.image_keys) > 0, "ds_1 should have image keys"
+
+    # Aggregate the datasets
+    aggregate_datasets(
+        repo_ids=[ds_0.repo_id, ds_1.repo_id],
+        roots=[ds_0.root, ds_1.root],
+        aggr_repo_id=f"{DUMMY_REPO_ID}_image_aggr",
+        aggr_root=tmp_path / "image_aggr",
+    )
+
+    # Load the aggregated dataset
+    with (
+        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
+        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
+    ):
+        mock_get_safe_version.return_value = "v3.0"
+        mock_snapshot_download.return_value = str(tmp_path / "image_aggr")
+        aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_image_aggr", root=tmp_path / "image_aggr")
+
+    # Verify aggregated dataset has image keys
+    assert len(aggr_ds.meta.image_keys) > 0, "Aggregated dataset should have image keys"
+    assert aggr_ds.meta.image_keys == ds_0.meta.image_keys, "Image keys should match source datasets"
+
+    # Run standard aggregation assertions
+    expected_total_episodes = ds_0_num_episodes + ds_1_num_episodes
+    expected_total_frames = ds_0_num_frames + ds_1_num_frames
+
+    assert_episode_and_frame_counts(aggr_ds, expected_total_episodes, expected_total_frames)
+    assert_dataset_content_integrity(aggr_ds, ds_0, ds_1)
+    assert_metadata_consistency(aggr_ds, ds_0, ds_1)
+    assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)
+
+    # Image-specific assertions
+    assert_image_schema_preserved(aggr_ds)
+    assert_image_frames_integrity(aggr_ds, ds_0, ds_1)
+
+    # Verify images can be accessed and have correct shape
+    sample_item = aggr_ds[0]
+    for image_key in aggr_ds.meta.image_keys:
+        img = sample_item[image_key]
+        assert isinstance(img, torch.Tensor), f"Image {image_key} should be a tensor"
+        assert img.dim() == 3, f"Image {image_key} should have 3 dimensions (C, H, W)"
+        assert img.shape[0] == 3, f"Image {image_key} should have 3 channels"
+
+    assert_dataset_iteration_works(aggr_ds)