From e38346316b5fa1cf59e92ea25a2994936f2326ae Mon Sep 17 00:00:00 2001
From: Pepijn
Date: Tue, 2 Dec 2025 18:27:50 +0100
Subject: [PATCH] add aggregate

---
 examples/dataset/aggregate_egodex.py | 213 +++++++++++++++++++++++++++
 1 file changed, 213 insertions(+)
 create mode 100644 examples/dataset/aggregate_egodex.py

diff --git a/examples/dataset/aggregate_egodex.py b/examples/dataset/aggregate_egodex.py
new file mode 100644
index 000000000..819eef42f
--- /dev/null
+++ b/examples/dataset/aggregate_egodex.py
@@ -0,0 +1,213 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Aggregate EgoDex shards into a single dataset.
+
+After distributed processing creates multiple shards, this script combines
+them into a single unified dataset.
+
+Reference: https://arxiv.org/abs/2505.11709, https://github.com/apple/ml-egodex
+"""
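+
+# Example invocation (an illustrative sketch; the repo id, shard count, and
+# local dir are placeholders). Shard repos are expected to follow the
+# <repo_id>_world_<num_shards>_rank_<rank> naming convention produced by
+# slurm_port_egodex.py:
+#
+#   python examples/dataset/aggregate_egodex.py \
+#       --repo-id pepijn/egodex-test \
+#       --num-shards 4 \
+#       --slurm 0 \
+#       --local-dir ./egodex_shards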
+
+import argparse
+import logging
+from pathlib import Path
+
+from datatrove.executor import LocalPipelineExecutor
+from datatrove.executor.slurm import SlurmPipelineExecutor
+from datatrove.pipeline.base import PipelineStep
+
+
+class AggregateEgoDexDatasets(PipelineStep):
+    """Datatrove pipeline step for aggregating EgoDex shards."""
+
+    def __init__(
+        self,
+        repo_ids: list[str],
+        aggregated_repo_id: str,
+        local_dir: Path | str | None = None,
+        push_to_hub: bool = False,
+    ):
+        super().__init__()
+        self.repo_ids = repo_ids
+        self.aggr_repo_id = aggregated_repo_id
+        self.local_dir = Path(local_dir) if local_dir else None
+        self.push_to_hub = push_to_hub
+
+    def run(self, data=None, rank: int = 0, world_size: int = 1):
+        import logging
+
+        from lerobot.datasets.aggregate import aggregate_datasets
+        from lerobot.utils.utils import init_logging
+
+        init_logging()
+
+        # Only worker 0 performs aggregation (aggregate_datasets handles parallelism internally)
+        if rank == 0:
+            logging.info(f"Starting aggregation of {len(self.repo_ids)} shards into {self.aggr_repo_id}")
+
+            # Build the roots list if local_dir is specified
+            roots = None
+            if self.local_dir:
+                roots = [self.local_dir / repo_id for repo_id in self.repo_ids]
+                # Keep only the shard directories that actually exist
+                roots = [r for r in roots if r.exists()]
+                if len(roots) != len(self.repo_ids):
+                    logging.warning(
+                        f"Only {len(roots)} of {len(self.repo_ids)} shard directories found. "
+                        "Missing shards will be skipped."
+                    )
+                    # Update repo_ids to match the existing roots; use the path
+                    # relative to local_dir so a namespace prefix such as
+                    # "user/name" is preserved (Path.name would drop it)
+                    self.repo_ids = [str(r.relative_to(self.local_dir)) for r in roots]
+
+            aggr_root = self.local_dir / self.aggr_repo_id if self.local_dir else None
+
+            aggregate_datasets(
+                repo_ids=self.repo_ids,
+                aggr_repo_id=self.aggr_repo_id,
+                roots=roots,
+                aggr_root=aggr_root,
+                push_to_hub=self.push_to_hub,
+            )
+            logging.info("Aggregation complete!")
+        else:
+            logging.info(f"Worker {rank} skipping - only worker 0 performs aggregation")
+
+
+def make_aggregate_executor(
+    repo_id,
+    num_shards,
+    job_name,
+    logs_dir,
+    partition,
+    cpus_per_task,
+    mem_per_cpu,
+    local_dir,
+    push_to_hub,
+    slurm=True,
+):
+    """Create executor for aggregating EgoDex shards."""
+    # Generate repo IDs for all shards produced by the porting script
+    repo_ids = [f"{repo_id}_world_{num_shards}_rank_{rank}" for rank in range(num_shards)]
+
+    kwargs = {
+        "pipeline": [
+            AggregateEgoDexDatasets(repo_ids, repo_id, local_dir, push_to_hub),
+        ],
+        "logging_dir": str(logs_dir / job_name),
+    }
+
+    if slurm:
+        kwargs.update(
+            {
+                "job_name": job_name,
+                "tasks": 1,  # Only need 1 task for aggregation
+                "workers": 1,  # Only need 1 worker
+                "time": "24:00:00",  # 24 hours for aggregation
+                "partition": partition,
+                "cpus_per_task": cpus_per_task,
+                "sbatch_args": {"mem-per-cpu": mem_per_cpu},
+            }
+        )
+        executor = SlurmPipelineExecutor(**kwargs)
+    else:
+        kwargs.update(
+            {
+                "tasks": 1,
+                "workers": 1,
+            }
+        )
+        executor = LocalPipelineExecutor(**kwargs)
+
+    return executor
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Aggregate EgoDex dataset shards into a single unified dataset."
+    )
+
+    parser.add_argument(
+        "--repo-id",
+        type=str,
+        required=True,
+        help="Repository identifier (base name without shard suffix, e.g., pepijn/egodex-test)",
+    )
+    parser.add_argument(
+        "--num-shards",
+        type=int,
+        required=True,
+        help="Number of shards to aggregate (must match --workers from slurm_port_egodex.py)",
+    )
+    parser.add_argument(
+        "--logs-dir",
+        type=Path,
+        default=Path("logs"),
+        help="Path to logs directory for datatrove",
+    )
+    parser.add_argument(
+        "--job-name",
+        type=str,
+        default="aggr_egodex",
+        help="Job name used in SLURM",
+    )
+    parser.add_argument(
+        "--slurm",
+        type=int,
+        default=1,
+        help="Launch over SLURM. Use --slurm 0 to launch locally (for debugging)",
+    )
+    parser.add_argument(
+        "--partition",
+        type=str,
+        help="SLURM partition (ideally a CPU partition)",
+    )
+    parser.add_argument(
+        "--cpus-per-task",
+        type=int,
+        default=16,
+        help="Number of CPUs for the aggregation task",
+    )
+    parser.add_argument(
+        "--mem-per-cpu",
+        type=str,
+        default="8G",
+        help="Memory per CPU for aggregation",
+    )
+    parser.add_argument(
+        "--local-dir",
+        type=Path,
+        default=None,
+        help="Local directory where shards are stored. If not specified, uses the default HF cache.",
+    )
+    parser.add_argument(
+        "--push-to-hub",
+        action="store_true",
+        help="Push the aggregated dataset to the Hugging Face Hub after aggregation.",
+    )
+
+    args = parser.parse_args()
+    kwargs = vars(args)
+    # Convert the 0/1 CLI flag into the boolean expected by make_aggregate_executor
+    kwargs["slurm"] = kwargs["slurm"] == 1
+
+    aggregate_executor = make_aggregate_executor(**kwargs)
+    aggregate_executor.run()
+
+
+if __name__ == "__main__":
+    main()
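+
+# Post-aggregation sanity check (a sketch; LeRobotDataset is lerobot's standard
+# dataset loader, and the repo id / root below are placeholders for the values
+# passed to this script):
+#
+#   from lerobot.datasets.lerobot_dataset import LeRobotDataset
+#
+#   ds = LeRobotDataset("pepijn/egodex-test", root="egodex_shards/pepijn/egodex-test")
+#   print(ds.num_episodes, ds.num_frames)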