mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 00:29:52 +00:00
Fix(aggregate.py) Aggregation of datasets when sub-datasets are already a result of a previous merge (#2861)
* Fix aggregation of datasets when sub-datasets are already a result of a previous merge * docstring * respond to copilot review + add regression test * Remove unnecessary int conversion for indices
This commit is contained in:
@@ -525,3 +525,92 @@ def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory):
|
||||
assert img.shape[0] == 3, f"Image {image_key} should have 3 channels"
|
||||
|
||||
assert_dataset_iteration_works(aggr_ds)
|
||||
|
||||
|
||||
def test_aggregate_already_merged_dataset(tmp_path, lerobot_dataset_factory):
    """Regression test: aggregate a dataset that is itself the output of a merge.

    Reproduces the bug where merging a dataset whose data was split across
    several parquet files (file rotation from an earlier merge) raised
    FileNotFoundError, because the metadata kept the original source file
    indices instead of mapping them to the files actually written at the
    destination.

    The fix adds src_to_dst tracking in aggregate_data() to correctly map
    source file indices to destination file indices.
    """

    def load_offline(repo_id, root):
        # Load a dataset from local disk only: stub out hub version lookup
        # and snapshot download so no network access happens.
        with (
            patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
            patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
        ):
            mock_get_safe_version.return_value = "v3.0"
            mock_snapshot_download.return_value = str(root)
            return LeRobotDataset(repo_id, root=root)

    # Step 1: Create datasets A and B
    ds_a, ds_b = (
        lerobot_dataset_factory(
            root=tmp_path / f"ds_{suffix}",
            repo_id=f"{DUMMY_REPO_ID}_{suffix}",
            total_episodes=4,
            total_frames=200,
        )
        for suffix in ("a", "b")
    )

    # Step 2: Merge A+B into AB with small file size to force multiple files
    ab_root = tmp_path / "ds_ab"
    aggregate_datasets(
        repo_ids=[ds_a.repo_id, ds_b.repo_id],
        roots=[ds_a.root, ds_b.root],
        aggr_repo_id=f"{DUMMY_REPO_ID}_ab",
        aggr_root=ab_root,
        data_files_size_in_mb=0.01,  # Force file rotation
    )
    ds_ab = load_offline(f"{DUMMY_REPO_ID}_ab", ab_root)

    # Verify AB has multiple data files (file rotation occurred)
    ab_data_files = list((ab_root / "data").rglob("*.parquet"))
    assert len(ab_data_files) > 1, "First merge should create multiple parquet files"

    # Step 3: Create dataset C
    ds_c = lerobot_dataset_factory(
        root=tmp_path / "ds_c",
        repo_id=f"{DUMMY_REPO_ID}_c",
        total_episodes=2,
        total_frames=100,
    )

    # Step 4: Merge AB+C into final - THIS IS WHERE THE BUG OCCURRED
    abc_root = tmp_path / "ds_abc"
    aggregate_datasets(
        repo_ids=[ds_ab.repo_id, ds_c.repo_id],
        roots=[ds_ab.root, ds_c.root],
        aggr_repo_id=f"{DUMMY_REPO_ID}_abc",
        aggr_root=abc_root,
    )
    ds_abc = load_offline(f"{DUMMY_REPO_ID}_abc", abc_root)

    # Step 5: Verify all data files referenced in metadata actually exist
    for ep_idx in range(ds_abc.num_episodes):
        data_file_path = ds_abc.root / ds_abc.meta.get_data_file_path(ep_idx)
        assert data_file_path.exists(), (
            f"Episode {ep_idx} references non-existent file: {data_file_path}\n"
            "This indicates the src_to_dst mapping fix is not working correctly."
        )

    # Step 6: Verify counts add up and full iteration works
    # (iteration raised FileNotFoundError before the fix).
    assert ds_abc.num_episodes == ds_a.num_episodes + ds_b.num_episodes + ds_c.num_episodes
    assert ds_abc.num_frames == ds_a.num_frames + ds_b.num_frames + ds_c.num_frames
    assert_dataset_iteration_works(ds_abc)
|
||||
|
||||
Reference in New Issue
Block a user