mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 11:39:50 +00:00
add: aggregation util
This commit is contained in:
@@ -1,3 +1,7 @@
|
|||||||
|
"""Aggregate multiple task-specific LeRobot datasets into a single combined dataset."""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from lerobot.datasets.aggregate import aggregate_datasets
|
from lerobot.datasets.aggregate import aggregate_datasets
|
||||||
@@ -5,15 +9,76 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""Aggregate all tasks datasets into a single LeRobotDataset and push it to the hub."""
|
parser = argparse.ArgumentParser(
|
||||||
task_indices = range(50)
|
description="Aggregate multiple task-specific datasets into a single LeRobot dataset"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--task-datasets-dir",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Directory containing individual task datasets (e.g., /path/to/behavior1k/)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--aggregated-root",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path where the aggregated dataset will be written",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-tasks",
|
||||||
|
type=int,
|
||||||
|
default=50,
|
||||||
|
help="Number of tasks to aggregate (default: 50)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--task-start-idx",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="Starting task index (default: 0)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--hf-user",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="HuggingFace username for repo IDs (defaults to HF_USER env var or 'lerobot')",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--aggregated-repo-id",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Repository ID for the aggregated dataset (defaults to {hf_user}/behavior1k)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--push-to-hub",
|
||||||
|
action="store_true",
|
||||||
|
help="Push the aggregated dataset to the Hugging Face Hub",
|
||||||
|
)
|
||||||
|
|
||||||
repo_ids = [f"fracapuano/behavior1k-task{i:04d}" for i in task_indices]
|
args = parser.parse_args()
|
||||||
|
|
||||||
roots = [Path(f"/fsx/francesco_capuano/behavior1k/behavior1k-task{i:04d}") for i in task_indices]
|
# Determine HF user
|
||||||
|
hf_user = args.hf_user or os.environ.get("HF_USER", "lerobot")
|
||||||
|
|
||||||
aggregated_root = Path("/fsx/francesco_capuano/behavior1k/behavior1k")
|
# Set default aggregated repo ID if not provided
|
||||||
aggregated_repo_id = "fracapuano/behavior1k"
|
aggregated_repo_id = args.aggregated_repo_id or f"{hf_user}/behavior1k"
|
||||||
|
|
||||||
|
# Generate task indices
|
||||||
|
task_indices = range(args.task_start_idx, args.task_start_idx + args.num_tasks)
|
||||||
|
|
||||||
|
# Generate repo IDs for individual tasks
|
||||||
|
repo_ids = [f"{hf_user}/behavior1k-task{i:04d}" for i in task_indices]
|
||||||
|
|
||||||
|
# Generate local paths for individual task datasets
|
||||||
|
task_datasets_dir = Path(args.task_datasets_dir)
|
||||||
|
roots = [task_datasets_dir / f"behavior1k-task{i:04d}" for i in task_indices]
|
||||||
|
|
||||||
|
# Aggregated dataset path
|
||||||
|
aggregated_root = Path(args.aggregated_root)
|
||||||
|
|
||||||
|
print(f"🔹 Aggregating {args.num_tasks} task datasets")
|
||||||
|
print(f"Task datasets directory: {task_datasets_dir}")
|
||||||
|
print(f"Aggregated output: {aggregated_root}")
|
||||||
|
print(f"Aggregated repo ID: {aggregated_repo_id}")
|
||||||
|
|
||||||
aggregate_datasets(
|
aggregate_datasets(
|
||||||
repo_ids=repo_ids,
|
repo_ids=repo_ids,
|
||||||
@@ -22,8 +87,13 @@ def main():
|
|||||||
aggr_root=aggregated_root,
|
aggr_root=aggregated_root,
|
||||||
)
|
)
|
||||||
|
|
||||||
ds = LeRobotDataset(repo_id=aggregated_repo_id, root=aggregated_root)
|
print("✅ Aggregation complete")
|
||||||
ds.push_to_hub()
|
|
||||||
|
if args.push_to_hub:
|
||||||
|
print(f"📤 Pushing aggregated dataset to {aggregated_repo_id}")
|
||||||
|
ds = LeRobotDataset(repo_id=aggregated_repo_id, root=aggregated_root)
|
||||||
|
ds.push_to_hub()
|
||||||
|
print("✅ Successfully pushed to hub")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user