add push to hub

This commit is contained in:
Pepijn
2025-12-02 18:30:11 +01:00
parent e38346316b
commit fa3919a0ff
+26 -5
View File
@@ -52,6 +52,7 @@ class AggregateEgoDexDatasets(PipelineStep):
import logging
from lerobot.datasets.aggregate import aggregate_datasets
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.utils import init_logging
init_logging()
@@ -65,14 +66,22 @@ class AggregateEgoDexDatasets(PipelineStep):
if self.local_dir:
roots = [self.local_dir / repo_id for repo_id in self.repo_ids]
# Filter to only existing directories
roots = [r for r in roots if r.exists()]
if len(roots) != len(self.repo_ids):
existing_roots = [r for r in roots if r.exists()]
if len(existing_roots) != len(self.repo_ids):
logging.warning(
f"Only {len(roots)} of {len(self.repo_ids)} shard directories found. "
f"Only {len(existing_roots)} of {len(self.repo_ids)} shard directories found. "
"Missing shards will be skipped."
)
# Update repo_ids to match existing roots
self.repo_ids = [r.name for r in roots]
existing_repo_ids = [
repo_id for repo_id, r in zip(self.repo_ids, roots, strict=False) if r.exists()
]
roots = existing_roots
self.repo_ids = existing_repo_ids
if len(self.repo_ids) == 0:
logging.error("No shard directories found. Nothing to aggregate.")
return
aggr_root = self.local_dir / self.aggr_repo_id if self.local_dir else None
@@ -81,9 +90,21 @@ class AggregateEgoDexDatasets(PipelineStep):
aggr_repo_id=self.aggr_repo_id,
roots=roots,
aggr_root=aggr_root,
push_to_hub=self.push_to_hub,
)
logging.info("Aggregation complete!")
# Push to Hugging Face Hub if requested
if self.push_to_hub:
logging.info(f"Pushing {self.aggr_repo_id} to Hugging Face Hub...")
dataset = LeRobotDataset(
repo_id=self.aggr_repo_id,
root=aggr_root,
)
dataset.push_to_hub(
tags=["egodex", "hand", "dexterous", "lerobot"],
license="cc-by-nc-nd-4.0",
)
logging.info("Push to hub complete!")
else:
logging.info(f"Worker {rank} skipping - only worker 0 performs aggregation")