From 927c6ac3c5de362d10d2cf16eae4a6646d05471e Mon Sep 17 00:00:00 2001 From: Francesco Capuano Date: Thu, 6 Nov 2025 00:49:41 +0000 Subject: [PATCH] add: parallel, distributed aggregation of multiple datasets with a tree-based thread pool --- src/lerobot/datasets/aggregate.py | 258 +++++++++++++++++++++++++----- 1 file changed, 217 insertions(+), 41 deletions(-) diff --git a/src/lerobot/datasets/aggregate.py b/src/lerobot/datasets/aggregate.py index 870c9571e..d781196ca 100644 --- a/src/lerobot/datasets/aggregate.py +++ b/src/lerobot/datasets/aggregate.py @@ -15,8 +15,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib import logging import shutil +from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path import pandas as pd @@ -107,6 +109,7 @@ def update_meta_data( dst_meta, meta_idx, data_idx, + data_file_map, videos_idx, ): """Updates metadata DataFrame with new chunk, file, and timestamp indices. @@ -127,8 +130,25 @@ def update_meta_data( df["meta/episodes/chunk_index"] = df["meta/episodes/chunk_index"] + meta_idx["chunk"] df["meta/episodes/file_index"] = df["meta/episodes/file_index"] + meta_idx["file"] - df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"] - df["data/file_index"] = df["data/file_index"] + data_idx["file"] + # Remap data chunk/file indices per-source-file using the actual destination + # file chosen during data aggregation. A flat offset is incorrect when + # multiple source files are concatenated into a single destination file. + if data_file_map: + new_data_chunk = [] + new_data_file = [] + for idx in df.index: + src_chunk = int(df.at[idx, "data/chunk_index"]) # original source file location + src_file = int(df.at[idx, "data/file_index"]) # original source file location + dst_chunk, dst_file = data_file_map.get( + (src_chunk, src_file), (src_chunk + data_idx["chunk"], src_file + data_idx["file"]) + ) + new_data_chunk.append(dst_chunk) + new_data_file.append(dst_file) + df["data/chunk_index"] = new_data_chunk + df["data/file_index"] = new_data_file + else: + df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"] + df["data/file_index"] = df["data/file_index"] + data_idx["file"] for key, video_idx in videos_idx.items(): # Store original video file indices before updating orig_chunk_col = f"videos/{key}/chunk_index" @@ -166,7 +186,7 @@ def update_meta_data( return df -def aggregate_datasets( +def _aggregate_datasets( repo_ids: list[str], aggr_repo_id: str, roots: list[Path] | None = None, @@ -175,39 +195,24 @@ def aggregate_datasets( video_files_size_in_mb: float | None = None, chunk_size: int | None = None, ): - """Aggregates multiple LeRobot datasets into a single unified dataset. + """Serial aggregation kernel: combines datasets into a destination dataset. - This is the main function that orchestrates the aggregation process by: - 1. Loading and validating all source dataset metadata - 2. Creating a new destination dataset with unified tasks - 3. Aggregating videos, data, and metadata from all source datasets - 4. Finalizing the aggregated dataset with proper statistics - - Args: - repo_ids: List of repository IDs for the datasets to aggregate. - aggr_repo_id: Repository ID for the aggregated output dataset. - roots: Optional list of root paths for the source datasets. - aggr_root: Optional root path for the aggregated dataset. - data_files_size_in_mb: Maximum size for data files in MB (defaults to DEFAULT_DATA_FILE_SIZE_IN_MB) - video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB) - chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE) + This function performs a single-process aggregation. It assumes it is the + sole writer for its destination `aggr_root`. """ - logging.info("Start aggregate_datasets") - - if data_files_size_in_mb is None: - data_files_size_in_mb = DEFAULT_DATA_FILE_SIZE_IN_MB - if video_files_size_in_mb is None: - video_files_size_in_mb = DEFAULT_VIDEO_FILE_SIZE_IN_MB - if chunk_size is None: - chunk_size = DEFAULT_CHUNK_SIZE - - all_metadata = ( - [LeRobotDatasetMetadata(repo_id) for repo_id in repo_ids] - if roots is None - else [ - LeRobotDatasetMetadata(repo_id, root=root) for repo_id, root in zip(repo_ids, roots, strict=False) + # Build metadata objects, supporting a per-dataset "root" that may be None. + # When root is provided we load from the local filesystem, otherwise from Hub cache. + if roots is None: + all_metadata = [LeRobotDatasetMetadata(repo_id) for repo_id in repo_ids] + else: + all_metadata = [ + ( + LeRobotDatasetMetadata(repo_id, root=root) + if root is not None + else LeRobotDatasetMetadata(repo_id) + ) + for repo_id, root in zip(repo_ids, roots, strict=False) ] - ) fps, robot_type, features = validate_all_metadata(all_metadata) video_keys = [key for key in features if features[key]["dtype"] == "video"] @@ -237,9 +242,11 @@ def aggregate_datasets( for src_meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"): videos_idx = aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size) - data_idx = aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size) + data_idx, data_file_map = aggregate_data( + src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size + ) - meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx) + meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, data_file_map, videos_idx) dst_meta.info["total_episodes"] += src_meta.total_episodes dst_meta.info["total_frames"] += src_meta.total_frames @@ -248,6 +255,168 @@ def aggregate_datasets( logging.info("Aggregation complete.") +def aggregate_datasets( + repo_ids: list[str], + aggr_repo_id: str, + roots: list[Path] | None = None, + aggr_root: Path | None = None, + data_files_size_in_mb: float | None = None, + video_files_size_in_mb: float | None = None, + chunk_size: int | None = None, + num_workers: int | None = None, + tmp_root: Path | None = None, +): + """Aggregates multiple LeRobot datasets into a single unified dataset. + + This is the main function that orchestrates the aggregation process by: + 1. Loading and validating all source dataset metadata + 2. Creating a new destination dataset with unified tasks + 3. Aggregating videos, data, and metadata from all source datasets + 4. Finalizing the aggregated dataset with proper statistics + + Args: + repo_ids: List of repository IDs for the datasets to aggregate. + aggr_repo_id: Repository ID for the aggregated output dataset. + roots: Optional list of root paths for the source datasets. + aggr_root: Optional root path for the aggregated dataset. + data_files_size_in_mb: Maximum size for data files in MB (defaults to DEFAULT_DATA_FILE_SIZE_IN_MB) + video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB) + chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE) + num_workers: When > 1, performs a tree-based parallel reduction using a thread pool + tmp_root: Optional base directory to store intermediate reduction outputs + """ + logging.info("Start aggregate_datasets") + + if data_files_size_in_mb is None: + data_files_size_in_mb = DEFAULT_DATA_FILE_SIZE_IN_MB + if video_files_size_in_mb is None: + video_files_size_in_mb = DEFAULT_VIDEO_FILE_SIZE_IN_MB + if chunk_size is None: + chunk_size = DEFAULT_CHUNK_SIZE + + if num_workers is None or num_workers <= 1: + # Run aggregation sequentially + _aggregate_datasets( + repo_ids=repo_ids, + aggr_repo_id=aggr_repo_id, + aggr_root=aggr_root, + roots=roots, + data_files_size_in_mb=data_files_size_in_mb, + video_files_size_in_mb=video_files_size_in_mb, + chunk_size=chunk_size, + ) + + # Uses a parallel fan-out/fan-in strategy when num_workers is provided + elif num_workers > 1: + # Validate across all metadata early to fail fast + all_metadata_for_validation = ( + [LeRobotDatasetMetadata(repo_id) for repo_id in repo_ids] + if roots is None + else [ + LeRobotDatasetMetadata(repo_id, root=root) + for repo_id, root in zip(repo_ids, roots, strict=False) + ] + ) + validate_all_metadata(all_metadata_for_validation) + + # Clamp workers to a sensible upper bound (pairs per round) + num_workers = min(num_workers, max(1, len(repo_ids) // 2)) + + # Choose a base temporary root for intermediate merge results + if tmp_root is not None: + base_tmp_root = tmp_root + elif aggr_root is not None: + base_tmp_root = aggr_root.parent / f".{aggr_repo_id}__tmp" + else: + base_tmp_root = Path.cwd() / f".{aggr_repo_id}__tmp" + base_tmp_root.mkdir(parents=True, exist_ok=True) + + current_repo_ids: list[str] = list(repo_ids) + # Always maintain a roots list aligned with repo_ids. Use None for Hub-backed inputs. + current_roots: list[Path | None] = list(roots) if roots is not None else [None] * len(repo_ids) + + try: + level = 0 + while len(current_repo_ids) > 1: + next_repo_ids: list[str] = [] + next_roots: list[Path | None] = [] + futures = [] + + with ThreadPoolExecutor(max_workers=num_workers) as executor: + group_index = 0 + i = 0 + while i < len(current_repo_ids): + group_repo_ids = current_repo_ids[i : i + 2] + group_roots = current_roots[i : i + 2] + + if len(group_repo_ids) == 1: + # Carry over singleton to next level + next_repo_ids.append(group_repo_ids[0]) + next_roots.append(group_roots[0]) + i += 1 + continue + + out_repo_id = f"{aggr_repo_id}__reduce_l{level}_g{group_index}" + out_root = base_tmp_root / f"reduce_l{level}_g{group_index}" + + futures.append( + executor.submit( + _aggregate_datasets, + group_repo_ids, + out_repo_id, + group_roots, + out_root, + data_files_size_in_mb, + video_files_size_in_mb, + chunk_size, + ) + ) + + next_repo_ids.append(out_repo_id) + next_roots.append(out_root) + + i += 2 + group_index += 1 + + for f in as_completed(futures): + # Bubble up any exception raised inside tasks + f.result() + + # Cleanup previous level temporary outputs that won't be used again + base_resolved = base_tmp_root.resolve() + keep_set = {nr.resolve() for nr in next_roots if nr is not None} + for prev_root in current_roots: + if prev_root is None: + continue + # Suppress per-iteration to keep cleaning other roots even if one fails + with contextlib.suppress(Exception): + pr = prev_root.resolve() + if pr not in keep_set and base_resolved in pr.parents: + shutil.rmtree(prev_root, ignore_errors=True) + + current_repo_ids = next_repo_ids + current_roots = next_roots # aligned list of Path|None after first level + level += 1 + + # Final copy/aggregation into the desired output + _aggregate_datasets( + repo_ids=current_repo_ids, + aggr_repo_id=aggr_repo_id, + roots=current_roots, + aggr_root=aggr_root, + data_files_size_in_mb=data_files_size_in_mb, + video_files_size_in_mb=video_files_size_in_mb, + chunk_size=chunk_size, + ) + finally: + # Remove all temporary reduction artifacts + with contextlib.suppress(Exception): + shutil.rmtree(base_tmp_root, ignore_errors=True) + + logging.info("Aggregation complete.") + return + + def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size): """Aggregates video chunks from a source dataset into the destination dataset. @@ -366,6 +535,9 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si unique_chunk_file_ids = sorted(unique_chunk_file_ids) + # Map source (chunk,file) -> destination (chunk,file) actually used during write + src_to_dst_file: dict[tuple[int, int], tuple[int, int]] = {} + for src_chunk_idx, src_file_idx in unique_chunk_file_ids: src_path = src_meta.root / DEFAULT_DATA_PATH.format( chunk_index=src_chunk_idx, file_index=src_file_idx @@ -373,7 +545,7 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si df = pd.read_parquet(src_path) df = update_data_df(df, src_meta, dst_meta) - data_idx = append_or_create_parquet_file( + data_idx, used_chunk, used_file = append_or_create_parquet_file( df, src_path, data_idx, @@ -383,11 +555,12 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si contains_images=len(dst_meta.image_keys) > 0, aggr_root=dst_meta.root, ) + src_to_dst_file[(src_chunk_idx, src_file_idx)] = (used_chunk, used_file) - return data_idx + return data_idx, src_to_dst_file -def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx): +def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, data_file_map, videos_idx): """Aggregates metadata from a source dataset into the destination dataset. Reads source metadata files, updates all indices and timestamps, @@ -421,10 +594,11 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx): dst_meta, meta_idx, data_idx, + data_file_map, videos_idx, ) - meta_idx = append_or_create_parquet_file( + meta_idx, _m_used_chunk, _m_used_file = append_or_create_parquet_file( df, src_path, meta_idx, @@ -478,7 +652,7 @@ def append_or_create_parquet_file( to_parquet_with_hf_images(df, dst_path) else: df.to_parquet(dst_path) - return idx + return idx, idx["chunk"], idx["file"] src_size = get_parquet_file_size_in_mb(src_path) dst_size = get_parquet_file_size_in_mb(dst_path) @@ -489,17 +663,19 @@ def append_or_create_parquet_file( new_path.parent.mkdir(parents=True, exist_ok=True) final_df = df target_path = new_path + used_chunk, used_file = idx["chunk"], idx["file"] else: existing_df = pd.read_parquet(dst_path) final_df = pd.concat([existing_df, df], ignore_index=True) target_path = dst_path + used_chunk, used_file = idx["chunk"], idx["file"] if contains_images: to_parquet_with_hf_images(final_df, target_path) else: final_df.to_parquet(target_path) - return idx + return idx, used_chunk, used_file def finalize_aggregation(aggr_meta, all_metadata):