Files
lerobot/examples/behavior_1k/aggregate_tasks_datasets.py
T
2025-11-21 09:30:38 +00:00

101 lines
3.0 KiB
Python

"""Aggregate multiple task-specific LeRobot datasets into a single combined dataset."""
import argparse
import os
from pathlib import Path
from lerobot.datasets.aggregate import aggregate_datasets
from lerobot.datasets.lerobot_dataset import LeRobotDataset
def _build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the task-dataset aggregation script."""
    parser = argparse.ArgumentParser(
        description="Aggregate multiple task-specific datasets into a single LeRobot dataset"
    )
    parser.add_argument(
        "--task-datasets-dir",
        type=str,
        required=True,
        help="Directory containing individual task datasets (e.g., /path/to/behavior1k/)",
    )
    parser.add_argument(
        "--aggregated-root",
        type=str,
        required=True,
        help="Path where the aggregated dataset will be written",
    )
    parser.add_argument(
        "--num-tasks",
        type=int,
        default=50,
        help="Number of tasks to aggregate (default: 50)",
    )
    parser.add_argument(
        "--task-start-idx",
        type=int,
        default=0,
        help="Starting task index (default: 0)",
    )
    parser.add_argument(
        "--hf-user",
        type=str,
        default=None,
        help="HuggingFace username for repo IDs (defaults to HF_USER env var or 'lerobot')",
    )
    parser.add_argument(
        "--aggregated-repo-id",
        type=str,
        default=None,
        help="Repository ID for the aggregated dataset (defaults to {hf_user}/behavior1k)",
    )
    parser.add_argument(
        "--push-to-hub",
        action="store_true",
        help="Push the aggregated dataset to the Hugging Face Hub",
    )
    return parser


def main(argv: "list[str] | None" = None) -> None:
    """Aggregate per-task datasets into one dataset and optionally push it to the Hub.

    Task datasets are addressed by index: task ``i`` maps to the repo ID
    ``{hf_user}/behavior1k-task{i:04d}`` and to the local directory
    ``<task-datasets-dir>/behavior1k-task{i:04d}``. The resulting lists are handed
    to ``aggregate_datasets`` together with the aggregated repo ID and root.

    Args:
        argv: Command-line arguments to parse, excluding the program name.
            ``None`` (the default) makes argparse read ``sys.argv[1:]``.

    Raises:
        SystemExit: With status 2 when the arguments are invalid, including a
            ``--num-tasks`` below 1 or a negative ``--task-start-idx``.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    # Reject ranges that would yield no datasets or malformed names such as
    # "behavior1k-task-001" before any aggregation work is started.
    if args.num_tasks < 1:
        parser.error(f"--num-tasks must be at least 1 (got {args.num_tasks})")
    if args.task_start_idx < 0:
        parser.error(f"--task-start-idx must be non-negative (got {args.task_start_idx})")

    # Precedence: --hf-user, then HF_USER, then 'lerobot'. Chaining with `or`
    # also covers an HF_USER that is set but empty, which would otherwise give
    # repo IDs starting with "/".
    hf_user = args.hf_user or os.environ.get("HF_USER") or "lerobot"
    aggregated_repo_id = args.aggregated_repo_id or f"{hf_user}/behavior1k"

    # One repo ID and one local root per task index, kept in matching order.
    task_indices = range(args.task_start_idx, args.task_start_idx + args.num_tasks)
    task_names = [f"behavior1k-task{i:04d}" for i in task_indices]
    task_datasets_dir = Path(args.task_datasets_dir)
    repo_ids = [f"{hf_user}/{name}" for name in task_names]
    roots = [task_datasets_dir / name for name in task_names]

    aggregated_root = Path(args.aggregated_root)

    print(f"🔹 Aggregating {args.num_tasks} task datasets")
    print(f"Task datasets directory: {task_datasets_dir}")
    print(f"Aggregated output: {aggregated_root}")
    print(f"Aggregated repo ID: {aggregated_repo_id}")

    aggregate_datasets(
        repo_ids=repo_ids,
        roots=roots,
        aggr_repo_id=aggregated_repo_id,
        aggr_root=aggregated_root,
    )
    print("✅ Aggregation complete")

    if args.push_to_hub:
        print(f"📤 Pushing aggregated dataset to {aggregated_repo_id}")
        # Re-open the freshly written dataset from disk so it can be uploaded.
        ds = LeRobotDataset(repo_id=aggregated_repo_id, root=aggregated_root)
        ds.push_to_hub()
        print("✅ Successfully pushed to hub")
# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()