mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 00:59:46 +00:00
101 lines
3.0 KiB
Python
101 lines
3.0 KiB
Python
"""Aggregate multiple task-specific LeRobot datasets into a single combined dataset."""
|
|
|
|
import argparse
|
|
import os
|
|
from pathlib import Path
|
|
|
|
from lerobot.datasets.aggregate import aggregate_datasets
|
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
|
|
|
|
|
def main():
|
|
parser = argparse.ArgumentParser(
|
|
description="Aggregate multiple task-specific datasets into a single LeRobot dataset"
|
|
)
|
|
parser.add_argument(
|
|
"--task-datasets-dir",
|
|
type=str,
|
|
required=True,
|
|
help="Directory containing individual task datasets (e.g., /path/to/behavior1k/)",
|
|
)
|
|
parser.add_argument(
|
|
"--aggregated-root",
|
|
type=str,
|
|
required=True,
|
|
help="Path where the aggregated dataset will be written",
|
|
)
|
|
parser.add_argument(
|
|
"--num-tasks",
|
|
type=int,
|
|
default=50,
|
|
help="Number of tasks to aggregate (default: 50)",
|
|
)
|
|
parser.add_argument(
|
|
"--task-start-idx",
|
|
type=int,
|
|
default=0,
|
|
help="Starting task index (default: 0)",
|
|
)
|
|
parser.add_argument(
|
|
"--hf-user",
|
|
type=str,
|
|
default=None,
|
|
help="HuggingFace username for repo IDs (defaults to HF_USER env var or 'lerobot')",
|
|
)
|
|
parser.add_argument(
|
|
"--aggregated-repo-id",
|
|
type=str,
|
|
default=None,
|
|
help="Repository ID for the aggregated dataset (defaults to {hf_user}/behavior1k)",
|
|
)
|
|
parser.add_argument(
|
|
"--push-to-hub",
|
|
action="store_true",
|
|
help="Push the aggregated dataset to the Hugging Face Hub",
|
|
)
|
|
|
|
args = parser.parse_args()
|
|
|
|
# Determine HF user
|
|
hf_user = args.hf_user or os.environ.get("HF_USER", "lerobot")
|
|
|
|
# Set default aggregated repo ID if not provided
|
|
aggregated_repo_id = args.aggregated_repo_id or f"{hf_user}/behavior1k"
|
|
|
|
# Generate task indices
|
|
task_indices = range(args.task_start_idx, args.task_start_idx + args.num_tasks)
|
|
|
|
# Generate repo IDs for individual tasks
|
|
repo_ids = [f"{hf_user}/behavior1k-task{i:04d}" for i in task_indices]
|
|
|
|
# Generate local paths for individual task datasets
|
|
task_datasets_dir = Path(args.task_datasets_dir)
|
|
roots = [task_datasets_dir / f"behavior1k-task{i:04d}" for i in task_indices]
|
|
|
|
# Aggregated dataset path
|
|
aggregated_root = Path(args.aggregated_root)
|
|
|
|
print(f"🔹 Aggregating {args.num_tasks} task datasets")
|
|
print(f"Task datasets directory: {task_datasets_dir}")
|
|
print(f"Aggregated output: {aggregated_root}")
|
|
print(f"Aggregated repo ID: {aggregated_repo_id}")
|
|
|
|
aggregate_datasets(
|
|
repo_ids=repo_ids,
|
|
roots=roots,
|
|
aggr_repo_id=aggregated_repo_id,
|
|
aggr_root=aggregated_root,
|
|
)
|
|
|
|
print("✅ Aggregation complete")
|
|
|
|
if args.push_to_hub:
|
|
print(f"📤 Pushing aggregated dataset to {aggregated_repo_id}")
|
|
ds = LeRobotDataset(repo_id=aggregated_repo_id, root=aggregated_root)
|
|
ds.push_to_hub()
|
|
print("✅ Successfully pushed to hub")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|