mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 03:59:42 +00:00
Fix(aggregate.py) Aggregation of datasets when sub-datasets are already a result of a previous merge (#2861)
* Fix aggeregation of datasets when subdatasets are already a result of a previous merge * docstring * respond to copilot review + add regression test * Remove unnecessary int conversion for indicies
This commit is contained in:
@@ -116,6 +116,9 @@ def update_meta_data(
|
|||||||
Adjusts all indices and timestamps to account for previously aggregated
|
Adjusts all indices and timestamps to account for previously aggregated
|
||||||
data and videos in the destination dataset.
|
data and videos in the destination dataset.
|
||||||
|
|
||||||
|
For data file indices, uses the 'src_to_dst' mapping from aggregate_data()
|
||||||
|
to correctly map source file indices to their destination locations.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
df: DataFrame containing the metadata to be updated.
|
df: DataFrame containing the metadata to be updated.
|
||||||
dst_meta: Destination dataset metadata.
|
dst_meta: Destination dataset metadata.
|
||||||
@@ -129,6 +132,48 @@ def update_meta_data(
|
|||||||
|
|
||||||
df["meta/episodes/chunk_index"] = df["meta/episodes/chunk_index"] + meta_idx["chunk"]
|
df["meta/episodes/chunk_index"] = df["meta/episodes/chunk_index"] + meta_idx["chunk"]
|
||||||
df["meta/episodes/file_index"] = df["meta/episodes/file_index"] + meta_idx["file"]
|
df["meta/episodes/file_index"] = df["meta/episodes/file_index"] + meta_idx["file"]
|
||||||
|
|
||||||
|
# Update data file indices using source-to-destination mapping
|
||||||
|
# This is critical for handling datasets that are already results of a merge
|
||||||
|
data_src_to_dst = data_idx.get("src_to_dst", {})
|
||||||
|
if data_src_to_dst:
|
||||||
|
# Store original indices for lookup
|
||||||
|
df["_orig_data_chunk"] = df["data/chunk_index"].copy()
|
||||||
|
df["_orig_data_file"] = df["data/file_index"].copy()
|
||||||
|
|
||||||
|
# Vectorized mapping from (src_chunk, src_file) to (dst_chunk, dst_file)
|
||||||
|
# This is much faster than per-row iteration for large metadata tables
|
||||||
|
mapping_index = pd.MultiIndex.from_tuples(
|
||||||
|
list(data_src_to_dst.keys()),
|
||||||
|
names=["chunk_index", "file_index"],
|
||||||
|
)
|
||||||
|
mapping_values = list(data_src_to_dst.values())
|
||||||
|
mapping_df = pd.DataFrame(
|
||||||
|
mapping_values,
|
||||||
|
index=mapping_index,
|
||||||
|
columns=["dst_chunk", "dst_file"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Construct a MultiIndex for each row based on original data indices
|
||||||
|
row_index = pd.MultiIndex.from_arrays(
|
||||||
|
[df["_orig_data_chunk"], df["_orig_data_file"]],
|
||||||
|
names=["chunk_index", "file_index"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Align mapping to rows; missing keys fall back to the default destination
|
||||||
|
reindexed = mapping_df.reindex(row_index)
|
||||||
|
reindexed[["dst_chunk", "dst_file"]] = reindexed[["dst_chunk", "dst_file"]].fillna(
|
||||||
|
{"dst_chunk": data_idx["chunk"], "dst_file": data_idx["file"]}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assign mapped destination indices back to the DataFrame
|
||||||
|
df["data/chunk_index"] = reindexed["dst_chunk"].to_numpy()
|
||||||
|
df["data/file_index"] = reindexed["dst_file"].to_numpy()
|
||||||
|
|
||||||
|
# Clean up temporary columns
|
||||||
|
df = df.drop(columns=["_orig_data_chunk", "_orig_data_file"])
|
||||||
|
else:
|
||||||
|
# Fallback to simple offset (backward compatibility for single-file sources)
|
||||||
df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"]
|
df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"]
|
||||||
df["data/file_index"] = df["data/file_index"] + data_idx["file"]
|
df["data/file_index"] = df["data/file_index"] + data_idx["file"]
|
||||||
for key, video_idx in videos_idx.items():
|
for key, video_idx in videos_idx.items():
|
||||||
@@ -146,8 +191,7 @@ def update_meta_data(
|
|||||||
if src_to_dst:
|
if src_to_dst:
|
||||||
# Map each episode to its correct destination file and apply offset
|
# Map each episode to its correct destination file and apply offset
|
||||||
for idx in df.index:
|
for idx in df.index:
|
||||||
# Convert to Python int to avoid numpy type mismatch in dict lookup
|
src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"])
|
||||||
src_key = (int(df.at[idx, "_orig_chunk"]), int(df.at[idx, "_orig_file"]))
|
|
||||||
|
|
||||||
# Get destination chunk/file for this source file
|
# Get destination chunk/file for this source file
|
||||||
dst_chunk, dst_file = src_to_dst.get(src_key, (video_idx["chunk"], video_idx["file"]))
|
dst_chunk, dst_file = src_to_dst.get(src_key, (video_idx["chunk"], video_idx["file"]))
|
||||||
@@ -163,8 +207,7 @@ def update_meta_data(
|
|||||||
df[orig_chunk_col] = video_idx["chunk"]
|
df[orig_chunk_col] = video_idx["chunk"]
|
||||||
df[orig_file_col] = video_idx["file"]
|
df[orig_file_col] = video_idx["file"]
|
||||||
for idx in df.index:
|
for idx in df.index:
|
||||||
# Convert to Python int to avoid numpy type mismatch in dict lookup
|
src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"])
|
||||||
src_key = (int(df.at[idx, "_orig_chunk"]), int(df.at[idx, "_orig_file"]))
|
|
||||||
offset = src_to_offset.get(src_key, 0)
|
offset = src_to_offset.get(src_key, 0)
|
||||||
df.at[idx, f"videos/{key}/from_timestamp"] += offset
|
df.at[idx, f"videos/{key}/from_timestamp"] += offset
|
||||||
df.at[idx, f"videos/{key}/to_timestamp"] += offset
|
df.at[idx, f"videos/{key}/to_timestamp"] += offset
|
||||||
@@ -262,6 +305,10 @@ def aggregate_datasets(
|
|||||||
|
|
||||||
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx)
|
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx)
|
||||||
|
|
||||||
|
# Clear the src_to_dst mapping after processing each source dataset
|
||||||
|
# to avoid interference between different source datasets
|
||||||
|
data_idx.pop("src_to_dst", None)
|
||||||
|
|
||||||
dst_meta.info["total_episodes"] += src_meta.total_episodes
|
dst_meta.info["total_episodes"] += src_meta.total_episodes
|
||||||
dst_meta.info["total_frames"] += src_meta.total_frames
|
dst_meta.info["total_frames"] += src_meta.total_frames
|
||||||
|
|
||||||
@@ -312,10 +359,6 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
|||||||
dst_file_durations = video_idx["dst_file_durations"]
|
dst_file_durations = video_idx["dst_file_durations"]
|
||||||
|
|
||||||
for src_chunk_idx, src_file_idx in unique_chunk_file_pairs:
|
for src_chunk_idx, src_file_idx in unique_chunk_file_pairs:
|
||||||
# Convert to Python int to ensure consistent dict keys
|
|
||||||
src_chunk_idx = int(src_chunk_idx)
|
|
||||||
src_file_idx = int(src_file_idx)
|
|
||||||
|
|
||||||
src_path = src_meta.root / DEFAULT_VIDEO_PATH.format(
|
src_path = src_meta.root / DEFAULT_VIDEO_PATH.format(
|
||||||
video_key=key,
|
video_key=key,
|
||||||
chunk_index=src_chunk_idx,
|
chunk_index=src_chunk_idx,
|
||||||
@@ -388,10 +431,16 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
|||||||
Reads source data files, updates indices to match the aggregated dataset,
|
Reads source data files, updates indices to match the aggregated dataset,
|
||||||
and writes them to the destination with proper file rotation.
|
and writes them to the destination with proper file rotation.
|
||||||
|
|
||||||
|
Tracks a `src_to_dst` mapping from source (chunk, file) to destination (chunk, file)
|
||||||
|
which is critical for correctly updating episode metadata when source datasets
|
||||||
|
have multiple data files (e.g., from a previous merge operation).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
src_meta: Source dataset metadata.
|
src_meta: Source dataset metadata.
|
||||||
dst_meta: Destination dataset metadata.
|
dst_meta: Destination dataset metadata.
|
||||||
data_idx: Dictionary tracking data chunk and file indices.
|
data_idx: Dictionary tracking data chunk and file indices.
|
||||||
|
data_files_size_in_mb: Maximum size for data files in MB.
|
||||||
|
chunk_size: Maximum number of files per chunk.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: Updated data_idx with current chunk and file indices.
|
dict: Updated data_idx with current chunk and file indices.
|
||||||
@@ -409,6 +458,10 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
|||||||
# retrieve features schema for proper image typing in parquet
|
# retrieve features schema for proper image typing in parquet
|
||||||
hf_features = get_hf_features_from_features(dst_meta.features) if contains_images else None
|
hf_features = get_hf_features_from_features(dst_meta.features) if contains_images else None
|
||||||
|
|
||||||
|
# Track source to destination file mapping for metadata update
|
||||||
|
# This is critical for handling datasets that are already results of a merge
|
||||||
|
src_to_dst: dict[tuple[int, int], tuple[int, int]] = {}
|
||||||
|
|
||||||
for src_chunk_idx, src_file_idx in unique_chunk_file_ids:
|
for src_chunk_idx, src_file_idx in unique_chunk_file_ids:
|
||||||
src_path = src_meta.root / DEFAULT_DATA_PATH.format(
|
src_path = src_meta.root / DEFAULT_DATA_PATH.format(
|
||||||
chunk_index=src_chunk_idx, file_index=src_file_idx
|
chunk_index=src_chunk_idx, file_index=src_file_idx
|
||||||
@@ -421,7 +474,9 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
|||||||
df = pd.read_parquet(src_path)
|
df = pd.read_parquet(src_path)
|
||||||
df = update_data_df(df, src_meta, dst_meta)
|
df = update_data_df(df, src_meta, dst_meta)
|
||||||
|
|
||||||
data_idx = append_or_create_parquet_file(
|
# Write data and get the actual destination file it was written to
|
||||||
|
# This avoids duplicating the rotation logic here
|
||||||
|
data_idx, (dst_chunk, dst_file) = append_or_create_parquet_file(
|
||||||
df,
|
df,
|
||||||
src_path,
|
src_path,
|
||||||
data_idx,
|
data_idx,
|
||||||
@@ -433,6 +488,12 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
|||||||
hf_features=hf_features,
|
hf_features=hf_features,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Record the mapping from source to actual destination
|
||||||
|
src_to_dst[(src_chunk_idx, src_file_idx)] = (dst_chunk, dst_file)
|
||||||
|
|
||||||
|
# Add the mapping to data_idx for use in metadata update
|
||||||
|
data_idx["src_to_dst"] = src_to_dst
|
||||||
|
|
||||||
return data_idx
|
return data_idx
|
||||||
|
|
||||||
|
|
||||||
@@ -473,7 +534,7 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
|
|||||||
videos_idx,
|
videos_idx,
|
||||||
)
|
)
|
||||||
|
|
||||||
meta_idx = append_or_create_parquet_file(
|
meta_idx, _ = append_or_create_parquet_file(
|
||||||
df,
|
df,
|
||||||
src_path,
|
src_path,
|
||||||
meta_idx,
|
meta_idx,
|
||||||
@@ -501,7 +562,7 @@ def append_or_create_parquet_file(
|
|||||||
contains_images: bool = False,
|
contains_images: bool = False,
|
||||||
aggr_root: Path = None,
|
aggr_root: Path = None,
|
||||||
hf_features: datasets.Features | None = None,
|
hf_features: datasets.Features | None = None,
|
||||||
):
|
) -> tuple[dict[str, int], tuple[int, int]]:
|
||||||
"""Appends data to an existing parquet file or creates a new one based on size constraints.
|
"""Appends data to an existing parquet file or creates a new one based on size constraints.
|
||||||
|
|
||||||
Manages file rotation when size limits are exceeded to prevent individual files
|
Manages file rotation when size limits are exceeded to prevent individual files
|
||||||
@@ -519,9 +580,11 @@ def append_or_create_parquet_file(
|
|||||||
hf_features: Optional HuggingFace Features schema for proper image typing.
|
hf_features: Optional HuggingFace Features schema for proper image typing.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: Updated index dictionary with current chunk and file indices.
|
tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict
|
||||||
|
and (dst_chunk, dst_file) is the actual destination file the data was written to.
|
||||||
"""
|
"""
|
||||||
dst_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"])
|
dst_chunk, dst_file = idx["chunk"], idx["file"]
|
||||||
|
dst_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file)
|
||||||
|
|
||||||
if not dst_path.exists():
|
if not dst_path.exists():
|
||||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
@@ -529,14 +592,15 @@ def append_or_create_parquet_file(
|
|||||||
to_parquet_with_hf_images(df, dst_path, features=hf_features)
|
to_parquet_with_hf_images(df, dst_path, features=hf_features)
|
||||||
else:
|
else:
|
||||||
df.to_parquet(dst_path)
|
df.to_parquet(dst_path)
|
||||||
return idx
|
return idx, (dst_chunk, dst_file)
|
||||||
|
|
||||||
src_size = get_parquet_file_size_in_mb(src_path)
|
src_size = get_parquet_file_size_in_mb(src_path)
|
||||||
dst_size = get_parquet_file_size_in_mb(dst_path)
|
dst_size = get_parquet_file_size_in_mb(dst_path)
|
||||||
|
|
||||||
if dst_size + src_size >= max_mb:
|
if dst_size + src_size >= max_mb:
|
||||||
idx["chunk"], idx["file"] = update_chunk_file_indices(idx["chunk"], idx["file"], chunk_size)
|
idx["chunk"], idx["file"] = update_chunk_file_indices(idx["chunk"], idx["file"], chunk_size)
|
||||||
new_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"])
|
dst_chunk, dst_file = idx["chunk"], idx["file"]
|
||||||
|
new_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file)
|
||||||
new_path.parent.mkdir(parents=True, exist_ok=True)
|
new_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
final_df = df
|
final_df = df
|
||||||
target_path = new_path
|
target_path = new_path
|
||||||
@@ -555,7 +619,7 @@ def append_or_create_parquet_file(
|
|||||||
else:
|
else:
|
||||||
final_df.to_parquet(target_path)
|
final_df.to_parquet(target_path)
|
||||||
|
|
||||||
return idx
|
return idx, (dst_chunk, dst_file)
|
||||||
|
|
||||||
|
|
||||||
def finalize_aggregation(aggr_meta, all_metadata):
|
def finalize_aggregation(aggr_meta, all_metadata):
|
||||||
|
|||||||
@@ -525,3 +525,92 @@ def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory):
|
|||||||
assert img.shape[0] == 3, f"Image {image_key} should have 3 channels"
|
assert img.shape[0] == 3, f"Image {image_key} should have 3 channels"
|
||||||
|
|
||||||
assert_dataset_iteration_works(aggr_ds)
|
assert_dataset_iteration_works(aggr_ds)
|
||||||
|
|
||||||
|
|
||||||
|
def test_aggregate_already_merged_dataset(tmp_path, lerobot_dataset_factory):
|
||||||
|
"""Regression test for aggregating a dataset that is itself a result of a previous merge.
|
||||||
|
|
||||||
|
This test reproduces the bug where merging datasets with multiple parquet files
|
||||||
|
(e.g., from a previous merge with file rotation) would cause FileNotFoundError
|
||||||
|
because metadata file indices were incorrectly preserved instead of being mapped
|
||||||
|
to their actual destination files.
|
||||||
|
|
||||||
|
The fix adds src_to_dst tracking in aggregate_data() to correctly map source
|
||||||
|
file indices to destination file indices.
|
||||||
|
"""
|
||||||
|
# Step 1: Create datasets A and B
|
||||||
|
ds_a = lerobot_dataset_factory(
|
||||||
|
root=tmp_path / "ds_a",
|
||||||
|
repo_id=f"{DUMMY_REPO_ID}_a",
|
||||||
|
total_episodes=4,
|
||||||
|
total_frames=200,
|
||||||
|
)
|
||||||
|
ds_b = lerobot_dataset_factory(
|
||||||
|
root=tmp_path / "ds_b",
|
||||||
|
repo_id=f"{DUMMY_REPO_ID}_b",
|
||||||
|
total_episodes=4,
|
||||||
|
total_frames=200,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 2: Merge A+B into AB with small file size to force multiple files
|
||||||
|
aggregate_datasets(
|
||||||
|
repo_ids=[ds_a.repo_id, ds_b.repo_id],
|
||||||
|
roots=[ds_a.root, ds_b.root],
|
||||||
|
aggr_repo_id=f"{DUMMY_REPO_ID}_ab",
|
||||||
|
aggr_root=tmp_path / "ds_ab",
|
||||||
|
data_files_size_in_mb=0.01, # Force file rotation
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||||
|
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||||
|
):
|
||||||
|
mock_get_safe_version.return_value = "v3.0"
|
||||||
|
mock_snapshot_download.return_value = str(tmp_path / "ds_ab")
|
||||||
|
ds_ab = LeRobotDataset(f"{DUMMY_REPO_ID}_ab", root=tmp_path / "ds_ab")
|
||||||
|
|
||||||
|
# Verify AB has multiple data files (file rotation occurred)
|
||||||
|
ab_data_files = list((tmp_path / "ds_ab" / "data").rglob("*.parquet"))
|
||||||
|
assert len(ab_data_files) > 1, "First merge should create multiple parquet files"
|
||||||
|
|
||||||
|
# Step 3: Create dataset C
|
||||||
|
ds_c = lerobot_dataset_factory(
|
||||||
|
root=tmp_path / "ds_c",
|
||||||
|
repo_id=f"{DUMMY_REPO_ID}_c",
|
||||||
|
total_episodes=2,
|
||||||
|
total_frames=100,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 4: Merge AB+C into final - THIS IS WHERE THE BUG OCCURRED
|
||||||
|
aggregate_datasets(
|
||||||
|
repo_ids=[ds_ab.repo_id, ds_c.repo_id],
|
||||||
|
roots=[ds_ab.root, ds_c.root],
|
||||||
|
aggr_repo_id=f"{DUMMY_REPO_ID}_abc",
|
||||||
|
aggr_root=tmp_path / "ds_abc",
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||||
|
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||||
|
):
|
||||||
|
mock_get_safe_version.return_value = "v3.0"
|
||||||
|
mock_snapshot_download.return_value = str(tmp_path / "ds_abc")
|
||||||
|
ds_abc = LeRobotDataset(f"{DUMMY_REPO_ID}_abc", root=tmp_path / "ds_abc")
|
||||||
|
|
||||||
|
# Step 5: Verify all data files referenced in metadata actually exist
|
||||||
|
for ep_idx in range(ds_abc.num_episodes):
|
||||||
|
data_file_path = ds_abc.root / ds_abc.meta.get_data_file_path(ep_idx)
|
||||||
|
assert data_file_path.exists(), (
|
||||||
|
f"Episode {ep_idx} references non-existent file: {data_file_path}\n"
|
||||||
|
"This indicates the src_to_dst mapping fix is not working correctly."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 6: Verify we can iterate through the entire dataset without FileNotFoundError
|
||||||
|
expected_episodes = ds_a.num_episodes + ds_b.num_episodes + ds_c.num_episodes
|
||||||
|
expected_frames = ds_a.num_frames + ds_b.num_frames + ds_c.num_frames
|
||||||
|
|
||||||
|
assert ds_abc.num_episodes == expected_episodes
|
||||||
|
assert ds_abc.num_frames == expected_frames
|
||||||
|
|
||||||
|
# This would raise FileNotFoundError before the fix
|
||||||
|
assert_dataset_iteration_works(ds_abc)
|
||||||
|
|||||||
Reference in New Issue
Block a user