From fa3919a0ff06d1b60828a8cfc1d77c9860b6fcb0 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Tue, 2 Dec 2025 18:30:11 +0100 Subject: [PATCH] add push to hub --- examples/dataset/aggregate_egodex.py | 31 +++++++++++++++++++++++----- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/examples/dataset/aggregate_egodex.py b/examples/dataset/aggregate_egodex.py index 819eef42f..eee02f83d 100644 --- a/examples/dataset/aggregate_egodex.py +++ b/examples/dataset/aggregate_egodex.py @@ -52,6 +52,7 @@ class AggregateEgoDexDatasets(PipelineStep): import logging from lerobot.datasets.aggregate import aggregate_datasets + from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.utils.utils import init_logging init_logging() @@ -65,14 +66,22 @@ class AggregateEgoDexDatasets(PipelineStep): if self.local_dir: roots = [self.local_dir / repo_id for repo_id in self.repo_ids] # Filter to only existing directories - roots = [r for r in roots if r.exists()] - if len(roots) != len(self.repo_ids): + existing_roots = [r for r in roots if r.exists()] + if len(existing_roots) != len(self.repo_ids): logging.warning( - f"Only {len(roots)} of {len(self.repo_ids)} shard directories found. " + f"Only {len(existing_roots)} of {len(self.repo_ids)} shard directories found. " "Missing shards will be skipped." ) # Update repo_ids to match existing roots - self.repo_ids = [r.name for r in roots] + existing_repo_ids = [ + repo_id for repo_id, r in zip(self.repo_ids, roots, strict=False) if r.exists() + ] + roots = existing_roots + self.repo_ids = existing_repo_ids + + if len(self.repo_ids) == 0: + logging.error("No shard directories found. Nothing to aggregate.") + return aggr_root = self.local_dir / self.aggr_repo_id if self.local_dir else None @@ -81,9 +90,21 @@ class AggregateEgoDexDatasets(PipelineStep): aggr_repo_id=self.aggr_repo_id, roots=roots, aggr_root=aggr_root, - push_to_hub=self.push_to_hub, ) logging.info("Aggregation complete!") + + # Push to Hugging Face Hub if requested + if self.push_to_hub: + logging.info(f"Pushing {self.aggr_repo_id} to Hugging Face Hub...") + dataset = LeRobotDataset( + repo_id=self.aggr_repo_id, + root=aggr_root, + ) + dataset.push_to_hub( + tags=["egodex", "hand", "dexterous", "lerobot"], + license="cc-by-nc-nd-4.0", + ) + logging.info("Push to hub complete!") else: logging.info(f"Worker {rank} skipping - only worker 0 performs aggregation")