add: aggregation util

This commit is contained in:
fracapuano
2025-11-21 09:27:47 +00:00
parent f6755dbf20
commit ca1841f5fc
@@ -1,3 +1,7 @@
"""Aggregate multiple task-specific LeRobot datasets into a single combined dataset."""
import argparse
import os
from pathlib import Path
from lerobot.datasets.aggregate import aggregate_datasets
@@ -5,15 +9,76 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
def main():
"""Aggregate all tasks datasets into a single LeRobotDataset and push it to the hub."""
task_indices = range(50)
parser = argparse.ArgumentParser(
description="Aggregate multiple task-specific datasets into a single LeRobot dataset"
)
parser.add_argument(
"--task-datasets-dir",
type=str,
required=True,
help="Directory containing individual task datasets (e.g., /path/to/behavior1k/)",
)
parser.add_argument(
"--aggregated-root",
type=str,
required=True,
help="Path where the aggregated dataset will be written",
)
parser.add_argument(
"--num-tasks",
type=int,
default=50,
help="Number of tasks to aggregate (default: 50)",
)
parser.add_argument(
"--task-start-idx",
type=int,
default=0,
help="Starting task index (default: 0)",
)
parser.add_argument(
"--hf-user",
type=str,
default=None,
help="HuggingFace username for repo IDs (defaults to HF_USER env var or 'lerobot')",
)
parser.add_argument(
"--aggregated-repo-id",
type=str,
default=None,
help="Repository ID for the aggregated dataset (defaults to {hf_user}/behavior1k)",
)
parser.add_argument(
"--push-to-hub",
action="store_true",
help="Push the aggregated dataset to the Hugging Face Hub",
)
repo_ids = [f"fracapuano/behavior1k-task{i:04d}" for i in task_indices]
args = parser.parse_args()
roots = [Path(f"/fsx/francesco_capuano/behavior1k/behavior1k-task{i:04d}") for i in task_indices]
# Determine HF user
hf_user = args.hf_user or os.environ.get("HF_USER", "lerobot")
aggregated_root = Path("/fsx/francesco_capuano/behavior1k/behavior1k")
aggregated_repo_id = "fracapuano/behavior1k"
# Set default aggregated repo ID if not provided
aggregated_repo_id = args.aggregated_repo_id or f"{hf_user}/behavior1k"
# Generate task indices
task_indices = range(args.task_start_idx, args.task_start_idx + args.num_tasks)
# Generate repo IDs for individual tasks
repo_ids = [f"{hf_user}/behavior1k-task{i:04d}" for i in task_indices]
# Generate local paths for individual task datasets
task_datasets_dir = Path(args.task_datasets_dir)
roots = [task_datasets_dir / f"behavior1k-task{i:04d}" for i in task_indices]
# Aggregated dataset path
aggregated_root = Path(args.aggregated_root)
print(f"🔹 Aggregating {args.num_tasks} task datasets")
print(f"Task datasets directory: {task_datasets_dir}")
print(f"Aggregated output: {aggregated_root}")
print(f"Aggregated repo ID: {aggregated_repo_id}")
aggregate_datasets(
repo_ids=repo_ids,
@@ -22,8 +87,13 @@ def main():
aggr_root=aggregated_root,
)
ds = LeRobotDataset(repo_id=aggregated_repo_id, root=aggregated_root)
ds.push_to_hub()
print("✅ Aggregation complete")
if args.push_to_hub:
print(f"📤 Pushing aggregated dataset to {aggregated_repo_id}")
ds = LeRobotDataset(repo_id=aggregated_repo_id, root=aggregated_root)
ds.push_to_hub()
print("✅ Successfully pushed to hub")
if __name__ == "__main__":