Fix(aggregate.py) Aggregation of datasets when sub-datasets are already a result of a previous merge (#2861)

* Fix aggregation of datasets when subdatasets are already a result of a previous merge

* docstring

* respond to copilot review + add regression test

* Remove unnecessary int conversion for indices
Michel Aractingi
2026-01-28 13:31:27 +01:00
committed by GitHub
parent f6b1c39b78
commit 736b43f3cf
2 changed files with 171 additions and 18 deletions
@@ -525,3 +525,92 @@ def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory):
        assert img.shape[0] == 3, f"Image {image_key} should have 3 channels"

    assert_dataset_iteration_works(aggr_ds)

def test_aggregate_already_merged_dataset(tmp_path, lerobot_dataset_factory):
    """Regression test for aggregating a dataset that is itself a result of a previous merge.

    This test reproduces the bug where merging datasets with multiple parquet files
    (e.g., from a previous merge with file rotation) would cause FileNotFoundError
    because metadata file indices were incorrectly preserved instead of being mapped
    to their actual destination files.

    The fix adds src_to_dst tracking in aggregate_data() to correctly map source
    file indices to destination file indices.
    """
    # Step 1: Create datasets A and B
    ds_a = lerobot_dataset_factory(
        root=tmp_path / "ds_a",
        repo_id=f"{DUMMY_REPO_ID}_a",
        total_episodes=4,
        total_frames=200,
    )
    ds_b = lerobot_dataset_factory(
        root=tmp_path / "ds_b",
        repo_id=f"{DUMMY_REPO_ID}_b",
        total_episodes=4,
        total_frames=200,
    )

    # Step 2: Merge A+B into AB with a small file size to force multiple files
    aggregate_datasets(
        repo_ids=[ds_a.repo_id, ds_b.repo_id],
        roots=[ds_a.root, ds_b.root],
        aggr_repo_id=f"{DUMMY_REPO_ID}_ab",
        aggr_root=tmp_path / "ds_ab",
        data_files_size_in_mb=0.01,  # Force file rotation
    )

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(tmp_path / "ds_ab")
        ds_ab = LeRobotDataset(f"{DUMMY_REPO_ID}_ab", root=tmp_path / "ds_ab")

    # Verify AB has multiple data files (file rotation occurred)
    ab_data_files = list((tmp_path / "ds_ab" / "data").rglob("*.parquet"))
    assert len(ab_data_files) > 1, "First merge should create multiple parquet files"

    # Step 3: Create dataset C
    ds_c = lerobot_dataset_factory(
        root=tmp_path / "ds_c",
        repo_id=f"{DUMMY_REPO_ID}_c",
        total_episodes=2,
        total_frames=100,
    )

    # Step 4: Merge AB+C into the final dataset - this is where the bug occurred
    aggregate_datasets(
        repo_ids=[ds_ab.repo_id, ds_c.repo_id],
        roots=[ds_ab.root, ds_c.root],
        aggr_repo_id=f"{DUMMY_REPO_ID}_abc",
        aggr_root=tmp_path / "ds_abc",
    )

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(tmp_path / "ds_abc")
        ds_abc = LeRobotDataset(f"{DUMMY_REPO_ID}_abc", root=tmp_path / "ds_abc")

    # Step 5: Verify all data files referenced in metadata actually exist
    for ep_idx in range(ds_abc.num_episodes):
        data_file_path = ds_abc.root / ds_abc.meta.get_data_file_path(ep_idx)
        assert data_file_path.exists(), (
            f"Episode {ep_idx} references non-existent file: {data_file_path}\n"
            "This indicates the src_to_dst mapping fix is not working correctly."
        )

    # Step 6: Verify we can iterate through the entire dataset without FileNotFoundError
    expected_episodes = ds_a.num_episodes + ds_b.num_episodes + ds_c.num_episodes
    expected_frames = ds_a.num_frames + ds_b.num_frames + ds_c.num_frames
    assert ds_abc.num_episodes == expected_episodes
    assert ds_abc.num_frames == expected_frames
    # This would raise FileNotFoundError before the fix
    assert_dataset_iteration_works(ds_abc)
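
For context on the fix the test exercises: the docstring above names src_to_dst tracking in aggregate_data(). The sketch below shows that remapping idea in isolation; it is a minimal illustration, not lerobot's actual implementation, and the function name and "file_index" field are hypothetical.

# Minimal sketch of src_to_dst remapping (hypothetical names, for illustration only).
def remap_episode_file_indices(src_episodes: list[dict], src_to_dst: dict[int, int]) -> list[dict]:
    """Rewrite each episode's parquet file index through the src -> dst map.

    src_to_dst is built while the source dataset's files are written into the
    destination, e.g. {0: 2, 1: 2} when two small source files are coalesced
    into the destination's file 2.
    """
    remapped = []
    for ep in src_episodes:
        # Map the source file index to where the rows actually landed,
        # instead of preserving the source index verbatim (the old bug).
        remapped.append({**ep, "file_index": src_to_dst[ep["file_index"]]})
    return remapped

# Usage: episodes stored in AB's files 0 and 1 both landed in the aggregate's file 2.
episodes = [{"episode_index": 0, "file_index": 0}, {"episode_index": 1, "file_index": 1}]
print(remap_episode_file_indices(episodes, src_to_dst={0: 2, 1: 2}))
# -> [{'episode_index': 0, 'file_index': 2}, {'episode_index': 1, 'file_index': 2}]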