mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 11:39:50 +00:00
Added the use of aggregate.py in slurm_aggregate_shards.py
This commit is contained in:
@@ -18,20 +18,12 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import tqdm
|
|
||||||
from datatrove.executor import LocalPipelineExecutor
|
from datatrove.executor import LocalPipelineExecutor
|
||||||
from datatrove.executor.slurm import SlurmPipelineExecutor
|
from datatrove.executor.slurm import SlurmPipelineExecutor
|
||||||
from datatrove.pipeline.base import PipelineStep
|
from datatrove.pipeline.base import PipelineStep
|
||||||
|
|
||||||
from examples.port_datasets.droid_rlds.port_droid import DROID_SHARDS
|
from examples.port_datasets.droid_rlds.port_droid import DROID_SHARDS
|
||||||
from lerobot.datasets.aggregate import validate_all_metadata
|
from lerobot.datasets.aggregate import aggregate_datasets
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
|
||||||
from lerobot.datasets.utils import (
|
|
||||||
legacy_write_episode_stats,
|
|
||||||
legacy_write_task,
|
|
||||||
write_episode,
|
|
||||||
write_info,
|
|
||||||
)
|
|
||||||
from lerobot.utils.utils import init_logging
|
from lerobot.utils.utils import init_logging
|
||||||
|
|
||||||
|
|
||||||
@@ -45,155 +37,17 @@ class AggregateDatasets(PipelineStep):
|
|||||||
self.repo_ids = repo_ids
|
self.repo_ids = repo_ids
|
||||||
self.aggr_repo_id = aggregated_repo_id
|
self.aggr_repo_id = aggregated_repo_id
|
||||||
|
|
||||||
self.create_aggr_dataset()
|
|
||||||
|
|
||||||
def create_aggr_dataset(self):
|
|
||||||
init_logging()
|
|
||||||
|
|
||||||
logging.info("Start aggregate_datasets")
|
|
||||||
|
|
||||||
all_metadata = [LeRobotDatasetMetadata(repo_id) for repo_id in self.repo_ids]
|
|
||||||
|
|
||||||
fps, robot_type, features = validate_all_metadata(all_metadata)
|
|
||||||
|
|
||||||
# Create resulting dataset folder
|
|
||||||
aggr_meta = LeRobotDatasetMetadata.create(
|
|
||||||
repo_id=self.aggr_repo_id,
|
|
||||||
fps=fps,
|
|
||||||
robot_type=robot_type,
|
|
||||||
features=features,
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.info("Find all tasks")
|
|
||||||
# find all tasks, deduplicate them, create new task indices for each dataset
|
|
||||||
# indexed by dataset index
|
|
||||||
datasets_task_index_to_aggr_task_index = {}
|
|
||||||
aggr_task_index = 0
|
|
||||||
for dataset_index, meta in enumerate(tqdm.tqdm(all_metadata, desc="Find all tasks")):
|
|
||||||
task_index_to_aggr_task_index = {}
|
|
||||||
|
|
||||||
for task_index, task in meta.tasks.items():
|
|
||||||
if task not in aggr_meta.task_to_task_index:
|
|
||||||
# add the task to aggr tasks mappings
|
|
||||||
aggr_meta.tasks[aggr_task_index] = task
|
|
||||||
aggr_meta.task_to_task_index[task] = aggr_task_index
|
|
||||||
aggr_task_index += 1
|
|
||||||
|
|
||||||
# add task_index anyway
|
|
||||||
task_index_to_aggr_task_index[task_index] = aggr_meta.task_to_task_index[task]
|
|
||||||
|
|
||||||
datasets_task_index_to_aggr_task_index[dataset_index] = task_index_to_aggr_task_index
|
|
||||||
|
|
||||||
logging.info("Prepare copy data and videos")
|
|
||||||
datasets_ep_idx_to_aggr_ep_idx = {}
|
|
||||||
datasets_aggr_episode_index_shift = {}
|
|
||||||
aggr_episode_index_shift = 0
|
|
||||||
for dataset_index, meta in enumerate(tqdm.tqdm(all_metadata, desc="Prepare copy data and videos")):
|
|
||||||
ep_idx_to_aggr_ep_idx = {}
|
|
||||||
|
|
||||||
for episode_index in range(meta.total_episodes):
|
|
||||||
aggr_episode_index = episode_index + aggr_episode_index_shift
|
|
||||||
ep_idx_to_aggr_ep_idx[episode_index] = aggr_episode_index
|
|
||||||
|
|
||||||
datasets_ep_idx_to_aggr_ep_idx[dataset_index] = ep_idx_to_aggr_ep_idx
|
|
||||||
datasets_aggr_episode_index_shift[dataset_index] = aggr_episode_index_shift
|
|
||||||
|
|
||||||
# populate episodes
|
|
||||||
for episode_index, episode_dict in meta.episodes.items():
|
|
||||||
aggr_episode_index = episode_index + aggr_episode_index_shift
|
|
||||||
episode_dict["episode_index"] = aggr_episode_index
|
|
||||||
aggr_meta.episodes[aggr_episode_index] = episode_dict
|
|
||||||
|
|
||||||
# populate episodes_stats
|
|
||||||
for episode_index, episode_stats in meta.episodes_stats.items():
|
|
||||||
aggr_episode_index = episode_index + aggr_episode_index_shift
|
|
||||||
aggr_meta.episodes_stats[aggr_episode_index] = episode_stats
|
|
||||||
|
|
||||||
# populate info
|
|
||||||
aggr_meta.info["total_episodes"] += meta.total_episodes
|
|
||||||
aggr_meta.info["total_frames"] += meta.total_frames
|
|
||||||
aggr_meta.info["total_videos"] += len(aggr_meta.video_keys) * meta.total_episodes
|
|
||||||
|
|
||||||
aggr_episode_index_shift += meta.total_episodes
|
|
||||||
|
|
||||||
logging.info("Write meta data")
|
|
||||||
aggr_meta.info["total_tasks"] = len(aggr_meta.tasks)
|
|
||||||
aggr_meta.info["total_chunks"] = aggr_meta.get_episode_chunk(aggr_episode_index_shift - 1)
|
|
||||||
aggr_meta.info["splits"] = {"train": f"0:{aggr_meta.info['total_episodes']}"}
|
|
||||||
|
|
||||||
# create a new episodes jsonl with updated episode_index using write_episode
|
|
||||||
for episode_dict in tqdm.tqdm(aggr_meta.episodes.values(), desc="Write episodes"):
|
|
||||||
write_episode(episode_dict, aggr_meta.root)
|
|
||||||
|
|
||||||
# create a new episode_stats jsonl with updated episode_index using write_episode_stats
|
|
||||||
for episode_index, episode_stats in tqdm.tqdm(
|
|
||||||
aggr_meta.episodes_stats.items(), desc="Write episodes stats"
|
|
||||||
):
|
|
||||||
legacy_write_episode_stats(episode_index, episode_stats, aggr_meta.root)
|
|
||||||
|
|
||||||
# create a new task jsonl with updated episode_index using write_task
|
|
||||||
for task_index, task in tqdm.tqdm(aggr_meta.tasks.items(), desc="Write tasks"):
|
|
||||||
legacy_write_task(task_index, task, aggr_meta.root)
|
|
||||||
|
|
||||||
write_info(aggr_meta.info, aggr_meta.root)
|
|
||||||
|
|
||||||
self.datasets_task_index_to_aggr_task_index = datasets_task_index_to_aggr_task_index
|
|
||||||
self.datasets_ep_idx_to_aggr_ep_idx = datasets_ep_idx_to_aggr_ep_idx
|
|
||||||
self.datasets_aggr_episode_index_shift = datasets_aggr_episode_index_shift
|
|
||||||
|
|
||||||
logging.info("Meta data done writing!")
|
|
||||||
|
|
||||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||||
import logging
|
|
||||||
import shutil
|
|
||||||
|
|
||||||
import pandas as pd
|
|
||||||
|
|
||||||
from lerobot.datasets.aggregate import get_update_episode_and_task_func
|
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
|
||||||
from lerobot.utils.utils import init_logging
|
|
||||||
|
|
||||||
init_logging()
|
init_logging()
|
||||||
|
|
||||||
aggr_meta = LeRobotDatasetMetadata(self.aggr_repo_id)
|
# Since aggregate_datasets already handles parallel processing internally,
|
||||||
all_metadata = [LeRobotDatasetMetadata(repo_id) for repo_id in self.repo_ids]
|
# we only need one worker to run the entire aggregation
|
||||||
|
if rank == 0:
|
||||||
if world_size != len(all_metadata):
|
logging.info(f"Starting aggregation of {len(self.repo_ids)} datasets into {self.aggr_repo_id}")
|
||||||
raise ValueError()
|
aggregate_datasets(self.repo_ids, self.aggr_repo_id)
|
||||||
|
logging.info("Aggregation complete!")
|
||||||
dataset_index = rank
|
else:
|
||||||
meta = all_metadata[dataset_index]
|
logging.info(f"Worker {rank} skipping - only worker 0 performs aggregation")
|
||||||
aggr_episode_index_shift = self.datasets_aggr_episode_index_shift[dataset_index]
|
|
||||||
|
|
||||||
logging.info("Copy data")
|
|
||||||
for episode_index in range(meta.total_episodes):
|
|
||||||
aggr_episode_index = self.datasets_ep_idx_to_aggr_ep_idx[dataset_index][episode_index]
|
|
||||||
data_path = meta.root / meta.get_data_file_path(episode_index)
|
|
||||||
aggr_data_path = aggr_meta.root / aggr_meta.get_data_file_path(aggr_episode_index)
|
|
||||||
|
|
||||||
# update episode_index and task_index
|
|
||||||
df = pd.read_parquet(data_path)
|
|
||||||
update_row_func = get_update_episode_and_task_func(
|
|
||||||
aggr_episode_index_shift, self.datasets_task_index_to_aggr_task_index[dataset_index]
|
|
||||||
)
|
|
||||||
df = df.apply(update_row_func, axis=1)
|
|
||||||
|
|
||||||
aggr_data_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
df.to_parquet(aggr_data_path)
|
|
||||||
|
|
||||||
logging.info("Copy videos")
|
|
||||||
for episode_index in range(meta.total_episodes):
|
|
||||||
aggr_episode_index = episode_index + aggr_episode_index_shift
|
|
||||||
for vid_key in meta.video_keys:
|
|
||||||
video_path = meta.root / meta.get_video_file_path(episode_index, vid_key)
|
|
||||||
aggr_video_path = aggr_meta.root / aggr_meta.get_video_file_path(aggr_episode_index, vid_key)
|
|
||||||
aggr_video_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
shutil.copy(video_path, aggr_video_path)
|
|
||||||
|
|
||||||
# copy_command = f"cp {video_path} {aggr_video_path} &"
|
|
||||||
# subprocess.Popen(copy_command, shell=True)
|
|
||||||
|
|
||||||
logging.info("Done!")
|
|
||||||
|
|
||||||
|
|
||||||
def make_aggregate_executor(
|
def make_aggregate_executor(
|
||||||
@@ -207,11 +61,12 @@ def make_aggregate_executor(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if slurm:
|
if slurm:
|
||||||
|
# For aggregation, we only need 1 task since aggregate_datasets handles everything
|
||||||
kwargs.update(
|
kwargs.update(
|
||||||
{
|
{
|
||||||
"job_name": job_name,
|
"job_name": job_name,
|
||||||
"tasks": DROID_SHARDS,
|
"tasks": 1, # Only need 1 task for aggregation
|
||||||
"workers": workers,
|
"workers": 1, # Only need 1 worker
|
||||||
"time": "08:00:00",
|
"time": "08:00:00",
|
||||||
"partition": partition,
|
"partition": partition,
|
||||||
"cpus_per_task": cpus_per_task,
|
"cpus_per_task": cpus_per_task,
|
||||||
@@ -222,7 +77,7 @@ def make_aggregate_executor(
|
|||||||
else:
|
else:
|
||||||
kwargs.update(
|
kwargs.update(
|
||||||
{
|
{
|
||||||
"tasks": DROID_SHARDS,
|
"tasks": 1,
|
||||||
"workers": 1,
|
"workers": 1,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -237,7 +92,7 @@ def main():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--repo-id",
|
"--repo-id",
|
||||||
type=str,
|
type=str,
|
||||||
help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset, required when push-to-hub is True.",
|
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset, required when push-to-hub is True.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--logs-dir",
|
"--logs-dir",
|
||||||
@@ -259,8 +114,8 @@ def main():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--workers",
|
"--workers",
|
||||||
type=int,
|
type=int,
|
||||||
default=2048,
|
default=1, # Changed default to 1 since aggregation doesn't need multiple workers
|
||||||
help="Number of slurm workers. It should be less than the maximum number of shards.",
|
help="Number of slurm workers. For aggregation, this should be 1.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--partition",
|
"--partition",
|
||||||
|
|||||||
Reference in New Issue
Block a user