From 736b43f3cfb5db2450fa787a45f645e1309caa00 Mon Sep 17 00:00:00 2001
From: Michel Aractingi
Date: Wed, 28 Jan 2026 13:31:27 +0100
Subject: [PATCH] Fix(aggregate.py) Aggregation of datasets when sub-datasets are already a result of a previous merge (#2861)

* Fix aggregation of datasets when sub-datasets are already a result of a previous merge

* docstring

* respond to copilot review + add regression test

* Remove unnecessary int conversion for indices
---
 src/lerobot/datasets/aggregate.py | 100 ++++++++++++++++++++++++------
 tests/datasets/test_aggregate.py  |  89 ++++++++++++++++++++++++++
 2 files changed, 171 insertions(+), 18 deletions(-)

diff --git a/src/lerobot/datasets/aggregate.py b/src/lerobot/datasets/aggregate.py
index 94ffe602e..7020545d2 100644
--- a/src/lerobot/datasets/aggregate.py
+++ b/src/lerobot/datasets/aggregate.py
@@ -116,6 +116,9 @@ def update_meta_data(
     Adjusts all indices and timestamps to account for previously aggregated data
     and videos in the destination dataset.
 
+    For data file indices, uses the 'src_to_dst' mapping from aggregate_data()
+    to correctly map source file indices to their destination locations.
+
     Args:
         df: DataFrame containing the metadata to be updated.
         dst_meta: Destination dataset metadata.
@@ -129,8 +132,50 @@ def update_meta_data(
     df["meta/episodes/chunk_index"] = df["meta/episodes/chunk_index"] + meta_idx["chunk"]
     df["meta/episodes/file_index"] = df["meta/episodes/file_index"] + meta_idx["file"]
-    df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"]
-    df["data/file_index"] = df["data/file_index"] + data_idx["file"]
+
+    # Update data file indices using source-to-destination mapping
+    # This is critical for handling datasets that are already results of a merge
+    data_src_to_dst = data_idx.get("src_to_dst", {})
+    if data_src_to_dst:
+        # Store original indices for lookup
+        df["_orig_data_chunk"] = df["data/chunk_index"].copy()
+        df["_orig_data_file"] = df["data/file_index"].copy()
+
+        # Vectorized mapping from (src_chunk, src_file) to (dst_chunk, dst_file)
+        # This is much faster than per-row iteration for large metadata tables
+        mapping_index = pd.MultiIndex.from_tuples(
+            list(data_src_to_dst.keys()),
+            names=["chunk_index", "file_index"],
+        )
+        mapping_values = list(data_src_to_dst.values())
+        mapping_df = pd.DataFrame(
+            mapping_values,
+            index=mapping_index,
+            columns=["dst_chunk", "dst_file"],
+        )
+
+        # Construct a MultiIndex for each row based on original data indices
+        row_index = pd.MultiIndex.from_arrays(
+            [df["_orig_data_chunk"], df["_orig_data_file"]],
+            names=["chunk_index", "file_index"],
+        )
+
+        # Align mapping to rows; missing keys fall back to the default destination
+        reindexed = mapping_df.reindex(row_index)
+        reindexed[["dst_chunk", "dst_file"]] = reindexed[["dst_chunk", "dst_file"]].fillna(
+            {"dst_chunk": data_idx["chunk"], "dst_file": data_idx["file"]}
+        )
+
+        # Assign mapped destination indices back to the DataFrame
+        df["data/chunk_index"] = reindexed["dst_chunk"].to_numpy()
+        df["data/file_index"] = reindexed["dst_file"].to_numpy()
+
+        # Clean up temporary columns
+        df = df.drop(columns=["_orig_data_chunk", "_orig_data_file"])
+    else:
+        # Fallback to simple offset (backward compatibility for single-file sources)
+        df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"]
+        df["data/file_index"] = df["data/file_index"] + data_idx["file"]
 
     for key, video_idx in videos_idx.items():
         # Store original video file indices before updating
         orig_chunk_col = f"videos/{key}/chunk_index"
@@ -146,8 +191,7 @@ def update_meta_data(
         if src_to_dst:
             # Map each episode to its correct destination file and apply offset
             for idx in df.index:
-                # Convert to Python int to avoid numpy type mismatch in dict lookup
-                src_key = (int(df.at[idx, "_orig_chunk"]), int(df.at[idx, "_orig_file"]))
+                src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"])
 
                 # Get destination chunk/file for this source file
                 dst_chunk, dst_file = src_to_dst.get(src_key, (video_idx["chunk"], video_idx["file"]))
@@ -163,8 +207,7 @@ def update_meta_data(
             df[orig_chunk_col] = video_idx["chunk"]
             df[orig_file_col] = video_idx["file"]
             for idx in df.index:
-                # Convert to Python int to avoid numpy type mismatch in dict lookup
-                src_key = (int(df.at[idx, "_orig_chunk"]), int(df.at[idx, "_orig_file"]))
+                src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"])
                 offset = src_to_offset.get(src_key, 0)
                 df.at[idx, f"videos/{key}/from_timestamp"] += offset
                 df.at[idx, f"videos/{key}/to_timestamp"] += offset
@@ -262,6 +305,10 @@ def aggregate_datasets(
 
         meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx)
 
+        # Clear the src_to_dst mapping after processing each source dataset
+        # to avoid interference between different source datasets
+        data_idx.pop("src_to_dst", None)
+
         dst_meta.info["total_episodes"] += src_meta.total_episodes
         dst_meta.info["total_frames"] += src_meta.total_frames
 
@@ -312,10 +359,6 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
         dst_file_durations = video_idx["dst_file_durations"]
 
         for src_chunk_idx, src_file_idx in unique_chunk_file_pairs:
-            # Convert to Python int to ensure consistent dict keys
-            src_chunk_idx = int(src_chunk_idx)
-            src_file_idx = int(src_file_idx)
-
             src_path = src_meta.root / DEFAULT_VIDEO_PATH.format(
                 video_key=key,
                 chunk_index=src_chunk_idx,
@@ -388,10 +431,16 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
     Reads source data files, updates indices to match the aggregated dataset,
     and writes them to the destination with proper file rotation.
 
+    Tracks a `src_to_dst` mapping from source (chunk, file) to destination (chunk, file)
+    which is critical for correctly updating episode metadata when source datasets
+    have multiple data files (e.g., from a previous merge operation).
+
     Args:
         src_meta: Source dataset metadata.
         dst_meta: Destination dataset metadata.
         data_idx: Dictionary tracking data chunk and file indices.
+        data_files_size_in_mb: Maximum size for data files in MB.
+        chunk_size: Maximum number of files per chunk.
 
     Returns:
         dict: Updated data_idx with current chunk and file indices.
@@ -409,6 +458,10 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
     # retrieve features schema for proper image typing in parquet
     hf_features = get_hf_features_from_features(dst_meta.features) if contains_images else None
 
+    # Track source to destination file mapping for metadata update
+    # This is critical for handling datasets that are already results of a merge
+    src_to_dst: dict[tuple[int, int], tuple[int, int]] = {}
+
     for src_chunk_idx, src_file_idx in unique_chunk_file_ids:
         src_path = src_meta.root / DEFAULT_DATA_PATH.format(
             chunk_index=src_chunk_idx, file_index=src_file_idx
@@ -421,7 +474,9 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
         df = pd.read_parquet(src_path)
         df = update_data_df(df, src_meta, dst_meta)
 
-        data_idx = append_or_create_parquet_file(
+        # Write data and get the actual destination file it was written to
+        # This avoids duplicating the rotation logic here
+        data_idx, (dst_chunk, dst_file) = append_or_create_parquet_file(
             df,
             src_path,
             data_idx,
@@ -433,6 +488,12 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
             hf_features=hf_features,
         )
 
+        # Record the mapping from source to actual destination
+        src_to_dst[(src_chunk_idx, src_file_idx)] = (dst_chunk, dst_file)
+
+    # Add the mapping to data_idx for use in metadata update
+    data_idx["src_to_dst"] = src_to_dst
+
     return data_idx
 
 
@@ -473,7 +534,7 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
             videos_idx,
         )
 
-        meta_idx = append_or_create_parquet_file(
+        meta_idx, _ = append_or_create_parquet_file(
             df,
             src_path,
             meta_idx,
@@ -501,7 +562,7 @@ def append_or_create_parquet_file(
     contains_images: bool = False,
     aggr_root: Path = None,
     hf_features: datasets.Features | None = None,
-):
+) -> tuple[dict[str, int], tuple[int, int]]:
     """Appends data to an existing parquet file or creates a new one based on size constraints.
 
     Manages file rotation when size limits are exceeded to prevent individual files
@@ -519,9 +580,11 @@ def append_or_create_parquet_file(
         hf_features: Optional HuggingFace Features schema for proper image typing.
 
     Returns:
-        dict: Updated index dictionary with current chunk and file indices.
+        tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict
+        and (dst_chunk, dst_file) is the actual destination file the data was written to.
""" - dst_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"]) + dst_chunk, dst_file = idx["chunk"], idx["file"] + dst_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file) if not dst_path.exists(): dst_path.parent.mkdir(parents=True, exist_ok=True) @@ -529,14 +592,15 @@ def append_or_create_parquet_file( to_parquet_with_hf_images(df, dst_path, features=hf_features) else: df.to_parquet(dst_path) - return idx + return idx, (dst_chunk, dst_file) src_size = get_parquet_file_size_in_mb(src_path) dst_size = get_parquet_file_size_in_mb(dst_path) if dst_size + src_size >= max_mb: idx["chunk"], idx["file"] = update_chunk_file_indices(idx["chunk"], idx["file"], chunk_size) - new_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"]) + dst_chunk, dst_file = idx["chunk"], idx["file"] + new_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file) new_path.parent.mkdir(parents=True, exist_ok=True) final_df = df target_path = new_path @@ -555,7 +619,7 @@ def append_or_create_parquet_file( else: final_df.to_parquet(target_path) - return idx + return idx, (dst_chunk, dst_file) def finalize_aggregation(aggr_meta, all_metadata): diff --git a/tests/datasets/test_aggregate.py b/tests/datasets/test_aggregate.py index 031c29d60..3609bac24 100644 --- a/tests/datasets/test_aggregate.py +++ b/tests/datasets/test_aggregate.py @@ -525,3 +525,92 @@ def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory): assert img.shape[0] == 3, f"Image {image_key} should have 3 channels" assert_dataset_iteration_works(aggr_ds) + + +def test_aggregate_already_merged_dataset(tmp_path, lerobot_dataset_factory): + """Regression test for aggregating a dataset that is itself a result of a previous merge. + + This test reproduces the bug where merging datasets with multiple parquet files + (e.g., from a previous merge with file rotation) would cause FileNotFoundError + because metadata file indices were incorrectly preserved instead of being mapped + to their actual destination files. + + The fix adds src_to_dst tracking in aggregate_data() to correctly map source + file indices to destination file indices. 
+ """ + # Step 1: Create datasets A and B + ds_a = lerobot_dataset_factory( + root=tmp_path / "ds_a", + repo_id=f"{DUMMY_REPO_ID}_a", + total_episodes=4, + total_frames=200, + ) + ds_b = lerobot_dataset_factory( + root=tmp_path / "ds_b", + repo_id=f"{DUMMY_REPO_ID}_b", + total_episodes=4, + total_frames=200, + ) + + # Step 2: Merge A+B into AB with small file size to force multiple files + aggregate_datasets( + repo_ids=[ds_a.repo_id, ds_b.repo_id], + roots=[ds_a.root, ds_b.root], + aggr_repo_id=f"{DUMMY_REPO_ID}_ab", + aggr_root=tmp_path / "ds_ab", + data_files_size_in_mb=0.01, # Force file rotation + ) + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "ds_ab") + ds_ab = LeRobotDataset(f"{DUMMY_REPO_ID}_ab", root=tmp_path / "ds_ab") + + # Verify AB has multiple data files (file rotation occurred) + ab_data_files = list((tmp_path / "ds_ab" / "data").rglob("*.parquet")) + assert len(ab_data_files) > 1, "First merge should create multiple parquet files" + + # Step 3: Create dataset C + ds_c = lerobot_dataset_factory( + root=tmp_path / "ds_c", + repo_id=f"{DUMMY_REPO_ID}_c", + total_episodes=2, + total_frames=100, + ) + + # Step 4: Merge AB+C into final - THIS IS WHERE THE BUG OCCURRED + aggregate_datasets( + repo_ids=[ds_ab.repo_id, ds_c.repo_id], + roots=[ds_ab.root, ds_c.root], + aggr_repo_id=f"{DUMMY_REPO_ID}_abc", + aggr_root=tmp_path / "ds_abc", + ) + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "ds_abc") + ds_abc = LeRobotDataset(f"{DUMMY_REPO_ID}_abc", root=tmp_path / "ds_abc") + + # Step 5: Verify all data files referenced in metadata actually exist + for ep_idx in range(ds_abc.num_episodes): + data_file_path = ds_abc.root / ds_abc.meta.get_data_file_path(ep_idx) + assert data_file_path.exists(), ( + f"Episode {ep_idx} references non-existent file: {data_file_path}\n" + "This indicates the src_to_dst mapping fix is not working correctly." + ) + + # Step 6: Verify we can iterate through the entire dataset without FileNotFoundError + expected_episodes = ds_a.num_episodes + ds_b.num_episodes + ds_c.num_episodes + expected_frames = ds_a.num_frames + ds_b.num_frames + ds_c.num_frames + + assert ds_abc.num_episodes == expected_episodes + assert ds_abc.num_frames == expected_frames + + # This would raise FileNotFoundError before the fix + assert_dataset_iteration_works(ds_abc)