Fix(aggregate.py) Aggregation of datasets when sub-datasets are already a result of a previous merge (#2861)

* Fix aggeregation of datasets when subdatasets are already a result of a previous merge

* docstring

* respond to copilot review + add regression test

* Remove unnecessary int conversion for indicies
This commit is contained in:
Michel Aractingi
2026-01-28 13:31:27 +01:00
committed by GitHub
parent f6b1c39b78
commit 736b43f3cf
2 changed files with 171 additions and 18 deletions
+80 -16
View File
@@ -116,6 +116,9 @@ def update_meta_data(
Adjusts all indices and timestamps to account for previously aggregated Adjusts all indices and timestamps to account for previously aggregated
data and videos in the destination dataset. data and videos in the destination dataset.
For data file indices, uses the 'src_to_dst' mapping from aggregate_data()
to correctly map source file indices to their destination locations.
Args: Args:
df: DataFrame containing the metadata to be updated. df: DataFrame containing the metadata to be updated.
dst_meta: Destination dataset metadata. dst_meta: Destination dataset metadata.
@@ -129,6 +132,48 @@ def update_meta_data(
df["meta/episodes/chunk_index"] = df["meta/episodes/chunk_index"] + meta_idx["chunk"] df["meta/episodes/chunk_index"] = df["meta/episodes/chunk_index"] + meta_idx["chunk"]
df["meta/episodes/file_index"] = df["meta/episodes/file_index"] + meta_idx["file"] df["meta/episodes/file_index"] = df["meta/episodes/file_index"] + meta_idx["file"]
# Update data file indices using source-to-destination mapping
# This is critical for handling datasets that are already results of a merge
data_src_to_dst = data_idx.get("src_to_dst", {})
if data_src_to_dst:
# Store original indices for lookup
df["_orig_data_chunk"] = df["data/chunk_index"].copy()
df["_orig_data_file"] = df["data/file_index"].copy()
# Vectorized mapping from (src_chunk, src_file) to (dst_chunk, dst_file)
# This is much faster than per-row iteration for large metadata tables
mapping_index = pd.MultiIndex.from_tuples(
list(data_src_to_dst.keys()),
names=["chunk_index", "file_index"],
)
mapping_values = list(data_src_to_dst.values())
mapping_df = pd.DataFrame(
mapping_values,
index=mapping_index,
columns=["dst_chunk", "dst_file"],
)
# Construct a MultiIndex for each row based on original data indices
row_index = pd.MultiIndex.from_arrays(
[df["_orig_data_chunk"], df["_orig_data_file"]],
names=["chunk_index", "file_index"],
)
# Align mapping to rows; missing keys fall back to the default destination
reindexed = mapping_df.reindex(row_index)
reindexed[["dst_chunk", "dst_file"]] = reindexed[["dst_chunk", "dst_file"]].fillna(
{"dst_chunk": data_idx["chunk"], "dst_file": data_idx["file"]}
)
# Assign mapped destination indices back to the DataFrame
df["data/chunk_index"] = reindexed["dst_chunk"].to_numpy()
df["data/file_index"] = reindexed["dst_file"].to_numpy()
# Clean up temporary columns
df = df.drop(columns=["_orig_data_chunk", "_orig_data_file"])
else:
# Fallback to simple offset (backward compatibility for single-file sources)
df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"] df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"]
df["data/file_index"] = df["data/file_index"] + data_idx["file"] df["data/file_index"] = df["data/file_index"] + data_idx["file"]
for key, video_idx in videos_idx.items(): for key, video_idx in videos_idx.items():
@@ -146,8 +191,7 @@ def update_meta_data(
if src_to_dst: if src_to_dst:
# Map each episode to its correct destination file and apply offset # Map each episode to its correct destination file and apply offset
for idx in df.index: for idx in df.index:
# Convert to Python int to avoid numpy type mismatch in dict lookup src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"])
src_key = (int(df.at[idx, "_orig_chunk"]), int(df.at[idx, "_orig_file"]))
# Get destination chunk/file for this source file # Get destination chunk/file for this source file
dst_chunk, dst_file = src_to_dst.get(src_key, (video_idx["chunk"], video_idx["file"])) dst_chunk, dst_file = src_to_dst.get(src_key, (video_idx["chunk"], video_idx["file"]))
@@ -163,8 +207,7 @@ def update_meta_data(
df[orig_chunk_col] = video_idx["chunk"] df[orig_chunk_col] = video_idx["chunk"]
df[orig_file_col] = video_idx["file"] df[orig_file_col] = video_idx["file"]
for idx in df.index: for idx in df.index:
# Convert to Python int to avoid numpy type mismatch in dict lookup src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"])
src_key = (int(df.at[idx, "_orig_chunk"]), int(df.at[idx, "_orig_file"]))
offset = src_to_offset.get(src_key, 0) offset = src_to_offset.get(src_key, 0)
df.at[idx, f"videos/{key}/from_timestamp"] += offset df.at[idx, f"videos/{key}/from_timestamp"] += offset
df.at[idx, f"videos/{key}/to_timestamp"] += offset df.at[idx, f"videos/{key}/to_timestamp"] += offset
@@ -262,6 +305,10 @@ def aggregate_datasets(
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx) meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx)
# Clear the src_to_dst mapping after processing each source dataset
# to avoid interference between different source datasets
data_idx.pop("src_to_dst", None)
dst_meta.info["total_episodes"] += src_meta.total_episodes dst_meta.info["total_episodes"] += src_meta.total_episodes
dst_meta.info["total_frames"] += src_meta.total_frames dst_meta.info["total_frames"] += src_meta.total_frames
@@ -312,10 +359,6 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
dst_file_durations = video_idx["dst_file_durations"] dst_file_durations = video_idx["dst_file_durations"]
for src_chunk_idx, src_file_idx in unique_chunk_file_pairs: for src_chunk_idx, src_file_idx in unique_chunk_file_pairs:
# Convert to Python int to ensure consistent dict keys
src_chunk_idx = int(src_chunk_idx)
src_file_idx = int(src_file_idx)
src_path = src_meta.root / DEFAULT_VIDEO_PATH.format( src_path = src_meta.root / DEFAULT_VIDEO_PATH.format(
video_key=key, video_key=key,
chunk_index=src_chunk_idx, chunk_index=src_chunk_idx,
@@ -388,10 +431,16 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
Reads source data files, updates indices to match the aggregated dataset, Reads source data files, updates indices to match the aggregated dataset,
and writes them to the destination with proper file rotation. and writes them to the destination with proper file rotation.
Tracks a `src_to_dst` mapping from source (chunk, file) to destination (chunk, file)
which is critical for correctly updating episode metadata when source datasets
have multiple data files (e.g., from a previous merge operation).
Args: Args:
src_meta: Source dataset metadata. src_meta: Source dataset metadata.
dst_meta: Destination dataset metadata. dst_meta: Destination dataset metadata.
data_idx: Dictionary tracking data chunk and file indices. data_idx: Dictionary tracking data chunk and file indices.
data_files_size_in_mb: Maximum size for data files in MB.
chunk_size: Maximum number of files per chunk.
Returns: Returns:
dict: Updated data_idx with current chunk and file indices. dict: Updated data_idx with current chunk and file indices.
@@ -409,6 +458,10 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
# retrieve features schema for proper image typing in parquet # retrieve features schema for proper image typing in parquet
hf_features = get_hf_features_from_features(dst_meta.features) if contains_images else None hf_features = get_hf_features_from_features(dst_meta.features) if contains_images else None
# Track source to destination file mapping for metadata update
# This is critical for handling datasets that are already results of a merge
src_to_dst: dict[tuple[int, int], tuple[int, int]] = {}
for src_chunk_idx, src_file_idx in unique_chunk_file_ids: for src_chunk_idx, src_file_idx in unique_chunk_file_ids:
src_path = src_meta.root / DEFAULT_DATA_PATH.format( src_path = src_meta.root / DEFAULT_DATA_PATH.format(
chunk_index=src_chunk_idx, file_index=src_file_idx chunk_index=src_chunk_idx, file_index=src_file_idx
@@ -421,7 +474,9 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
df = pd.read_parquet(src_path) df = pd.read_parquet(src_path)
df = update_data_df(df, src_meta, dst_meta) df = update_data_df(df, src_meta, dst_meta)
data_idx = append_or_create_parquet_file( # Write data and get the actual destination file it was written to
# This avoids duplicating the rotation logic here
data_idx, (dst_chunk, dst_file) = append_or_create_parquet_file(
df, df,
src_path, src_path,
data_idx, data_idx,
@@ -433,6 +488,12 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
hf_features=hf_features, hf_features=hf_features,
) )
# Record the mapping from source to actual destination
src_to_dst[(src_chunk_idx, src_file_idx)] = (dst_chunk, dst_file)
# Add the mapping to data_idx for use in metadata update
data_idx["src_to_dst"] = src_to_dst
return data_idx return data_idx
@@ -473,7 +534,7 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
videos_idx, videos_idx,
) )
meta_idx = append_or_create_parquet_file( meta_idx, _ = append_or_create_parquet_file(
df, df,
src_path, src_path,
meta_idx, meta_idx,
@@ -501,7 +562,7 @@ def append_or_create_parquet_file(
contains_images: bool = False, contains_images: bool = False,
aggr_root: Path = None, aggr_root: Path = None,
hf_features: datasets.Features | None = None, hf_features: datasets.Features | None = None,
): ) -> tuple[dict[str, int], tuple[int, int]]:
"""Appends data to an existing parquet file or creates a new one based on size constraints. """Appends data to an existing parquet file or creates a new one based on size constraints.
Manages file rotation when size limits are exceeded to prevent individual files Manages file rotation when size limits are exceeded to prevent individual files
@@ -519,9 +580,11 @@ def append_or_create_parquet_file(
hf_features: Optional HuggingFace Features schema for proper image typing. hf_features: Optional HuggingFace Features schema for proper image typing.
Returns: Returns:
dict: Updated index dictionary with current chunk and file indices. tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict
and (dst_chunk, dst_file) is the actual destination file the data was written to.
""" """
dst_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"]) dst_chunk, dst_file = idx["chunk"], idx["file"]
dst_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file)
if not dst_path.exists(): if not dst_path.exists():
dst_path.parent.mkdir(parents=True, exist_ok=True) dst_path.parent.mkdir(parents=True, exist_ok=True)
@@ -529,14 +592,15 @@ def append_or_create_parquet_file(
to_parquet_with_hf_images(df, dst_path, features=hf_features) to_parquet_with_hf_images(df, dst_path, features=hf_features)
else: else:
df.to_parquet(dst_path) df.to_parquet(dst_path)
return idx return idx, (dst_chunk, dst_file)
src_size = get_parquet_file_size_in_mb(src_path) src_size = get_parquet_file_size_in_mb(src_path)
dst_size = get_parquet_file_size_in_mb(dst_path) dst_size = get_parquet_file_size_in_mb(dst_path)
if dst_size + src_size >= max_mb: if dst_size + src_size >= max_mb:
idx["chunk"], idx["file"] = update_chunk_file_indices(idx["chunk"], idx["file"], chunk_size) idx["chunk"], idx["file"] = update_chunk_file_indices(idx["chunk"], idx["file"], chunk_size)
new_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"]) dst_chunk, dst_file = idx["chunk"], idx["file"]
new_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file)
new_path.parent.mkdir(parents=True, exist_ok=True) new_path.parent.mkdir(parents=True, exist_ok=True)
final_df = df final_df = df
target_path = new_path target_path = new_path
@@ -555,7 +619,7 @@ def append_or_create_parquet_file(
else: else:
final_df.to_parquet(target_path) final_df.to_parquet(target_path)
return idx return idx, (dst_chunk, dst_file)
def finalize_aggregation(aggr_meta, all_metadata): def finalize_aggregation(aggr_meta, all_metadata):
+89
View File
@@ -525,3 +525,92 @@ def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory):
assert img.shape[0] == 3, f"Image {image_key} should have 3 channels" assert img.shape[0] == 3, f"Image {image_key} should have 3 channels"
assert_dataset_iteration_works(aggr_ds) assert_dataset_iteration_works(aggr_ds)
def test_aggregate_already_merged_dataset(tmp_path, lerobot_dataset_factory):
"""Regression test for aggregating a dataset that is itself a result of a previous merge.
This test reproduces the bug where merging datasets with multiple parquet files
(e.g., from a previous merge with file rotation) would cause FileNotFoundError
because metadata file indices were incorrectly preserved instead of being mapped
to their actual destination files.
The fix adds src_to_dst tracking in aggregate_data() to correctly map source
file indices to destination file indices.
"""
# Step 1: Create datasets A and B
ds_a = lerobot_dataset_factory(
root=tmp_path / "ds_a",
repo_id=f"{DUMMY_REPO_ID}_a",
total_episodes=4,
total_frames=200,
)
ds_b = lerobot_dataset_factory(
root=tmp_path / "ds_b",
repo_id=f"{DUMMY_REPO_ID}_b",
total_episodes=4,
total_frames=200,
)
# Step 2: Merge A+B into AB with small file size to force multiple files
aggregate_datasets(
repo_ids=[ds_a.repo_id, ds_b.repo_id],
roots=[ds_a.root, ds_b.root],
aggr_repo_id=f"{DUMMY_REPO_ID}_ab",
aggr_root=tmp_path / "ds_ab",
data_files_size_in_mb=0.01, # Force file rotation
)
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(tmp_path / "ds_ab")
ds_ab = LeRobotDataset(f"{DUMMY_REPO_ID}_ab", root=tmp_path / "ds_ab")
# Verify AB has multiple data files (file rotation occurred)
ab_data_files = list((tmp_path / "ds_ab" / "data").rglob("*.parquet"))
assert len(ab_data_files) > 1, "First merge should create multiple parquet files"
# Step 3: Create dataset C
ds_c = lerobot_dataset_factory(
root=tmp_path / "ds_c",
repo_id=f"{DUMMY_REPO_ID}_c",
total_episodes=2,
total_frames=100,
)
# Step 4: Merge AB+C into final - THIS IS WHERE THE BUG OCCURRED
aggregate_datasets(
repo_ids=[ds_ab.repo_id, ds_c.repo_id],
roots=[ds_ab.root, ds_c.root],
aggr_repo_id=f"{DUMMY_REPO_ID}_abc",
aggr_root=tmp_path / "ds_abc",
)
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(tmp_path / "ds_abc")
ds_abc = LeRobotDataset(f"{DUMMY_REPO_ID}_abc", root=tmp_path / "ds_abc")
# Step 5: Verify all data files referenced in metadata actually exist
for ep_idx in range(ds_abc.num_episodes):
data_file_path = ds_abc.root / ds_abc.meta.get_data_file_path(ep_idx)
assert data_file_path.exists(), (
f"Episode {ep_idx} references non-existent file: {data_file_path}\n"
"This indicates the src_to_dst mapping fix is not working correctly."
)
# Step 6: Verify we can iterate through the entire dataset without FileNotFoundError
expected_episodes = ds_a.num_episodes + ds_b.num_episodes + ds_c.num_episodes
expected_frames = ds_a.num_frames + ds_b.num_frames + ds_c.num_frames
assert ds_abc.num_episodes == expected_episodes
assert ds_abc.num_frames == expected_frames
# This would raise FileNotFoundError before the fix
assert_dataset_iteration_works(ds_abc)