mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 09:09:48 +00:00
add: aggregation util
This commit is contained in:
@@ -1,3 +1,7 @@
|
||||
"""Aggregate multiple task-specific LeRobot datasets into a single combined dataset."""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.datasets.aggregate import aggregate_datasets
|
||||
@@ -5,15 +9,76 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
|
||||
def main():
    """Aggregate all task-specific datasets into a single LeRobotDataset.

    Parses CLI arguments to locate the individual per-task datasets on disk,
    aggregates them into one combined dataset, and optionally pushes the
    result to the Hugging Face Hub (only when --push-to-hub is given).
    """
    parser = argparse.ArgumentParser(
        description="Aggregate multiple task-specific datasets into a single LeRobot dataset"
    )
    parser.add_argument(
        "--task-datasets-dir",
        type=str,
        required=True,
        help="Directory containing individual task datasets (e.g., /path/to/behavior1k/)",
    )
    parser.add_argument(
        "--aggregated-root",
        type=str,
        required=True,
        help="Path where the aggregated dataset will be written",
    )
    parser.add_argument(
        "--num-tasks",
        type=int,
        default=50,
        help="Number of tasks to aggregate (default: 50)",
    )
    parser.add_argument(
        "--task-start-idx",
        type=int,
        default=0,
        help="Starting task index (default: 0)",
    )
    parser.add_argument(
        "--hf-user",
        type=str,
        default=None,
        help="HuggingFace username for repo IDs (defaults to HF_USER env var or 'lerobot')",
    )
    parser.add_argument(
        "--aggregated-repo-id",
        type=str,
        default=None,
        help="Repository ID for the aggregated dataset (defaults to {hf_user}/behavior1k)",
    )
    parser.add_argument(
        "--push-to-hub",
        action="store_true",
        help="Push the aggregated dataset to the Hugging Face Hub",
    )
    args = parser.parse_args()

    # Determine HF user: CLI flag wins, then the HF_USER env var, then "lerobot".
    hf_user = args.hf_user or os.environ.get("HF_USER", "lerobot")

    # Set default aggregated repo ID if not provided.
    aggregated_repo_id = args.aggregated_repo_id or f"{hf_user}/behavior1k"

    # Generate task indices (half-open range of args.num_tasks tasks).
    task_indices = range(args.task_start_idx, args.task_start_idx + args.num_tasks)

    # Generate repo IDs for individual tasks; indices are zero-padded to 4 digits.
    repo_ids = [f"{hf_user}/behavior1k-task{i:04d}" for i in task_indices]

    # Generate local paths for individual task datasets.
    task_datasets_dir = Path(args.task_datasets_dir)
    roots = [task_datasets_dir / f"behavior1k-task{i:04d}" for i in task_indices]

    # Aggregated dataset path.
    aggregated_root = Path(args.aggregated_root)

    print(f"🔹 Aggregating {args.num_tasks} task datasets")
    print(f"Task datasets directory: {task_datasets_dir}")
    print(f"Aggregated output: {aggregated_root}")
    print(f"Aggregated repo ID: {aggregated_repo_id}")

    aggregate_datasets(
        repo_ids=repo_ids,
        # NOTE(review): the diff hunk elided the kwargs between `repo_ids` and
        # `aggr_root`; `roots` / `aggr_repo_id` below are reconstructed from the
        # variables computed above — confirm against lerobot.datasets.aggregate.
        roots=roots,
        aggr_repo_id=aggregated_repo_id,
        aggr_root=aggregated_root,
    )

    if args.push_to_hub:
        print(f"📤 Pushing aggregated dataset to {aggregated_repo_id}")
        ds = LeRobotDataset(repo_id=aggregated_repo_id, root=aggregated_root)
        ds.push_to_hub()
        print("✅ Successfully pushed to hub")
||||
if __name__ == "__main__":
    # Standard script entry point: only run when executed directly.
    main()
||||
Reference in New Issue
Block a user