Added the use of aggregate.py in slurm_aggregate_shards.py

This commit is contained in:
Michel Aractingi
2025-07-07 13:51:08 +02:00
parent 4a466d94b6
commit 18209e6194
@@ -18,20 +18,12 @@ import argparse
import logging import logging
from pathlib import Path from pathlib import Path
import tqdm
from datatrove.executor import LocalPipelineExecutor from datatrove.executor import LocalPipelineExecutor
from datatrove.executor.slurm import SlurmPipelineExecutor from datatrove.executor.slurm import SlurmPipelineExecutor
from datatrove.pipeline.base import PipelineStep from datatrove.pipeline.base import PipelineStep
from examples.port_datasets.droid_rlds.port_droid import DROID_SHARDS from examples.port_datasets.droid_rlds.port_droid import DROID_SHARDS
from lerobot.datasets.aggregate import validate_all_metadata from lerobot.datasets.aggregate import aggregate_datasets
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.datasets.utils import (
legacy_write_episode_stats,
legacy_write_task,
write_episode,
write_info,
)
from lerobot.utils.utils import init_logging from lerobot.utils.utils import init_logging
@@ -45,155 +37,17 @@ class AggregateDatasets(PipelineStep):
self.repo_ids = repo_ids self.repo_ids = repo_ids
self.aggr_repo_id = aggregated_repo_id self.aggr_repo_id = aggregated_repo_id
self.create_aggr_dataset()
def create_aggr_dataset(self):
init_logging()
logging.info("Start aggregate_datasets")
all_metadata = [LeRobotDatasetMetadata(repo_id) for repo_id in self.repo_ids]
fps, robot_type, features = validate_all_metadata(all_metadata)
# Create resulting dataset folder
aggr_meta = LeRobotDatasetMetadata.create(
repo_id=self.aggr_repo_id,
fps=fps,
robot_type=robot_type,
features=features,
)
logging.info("Find all tasks")
# find all tasks, deduplicate them, create new task indices for each dataset
# indexed by dataset index
datasets_task_index_to_aggr_task_index = {}
aggr_task_index = 0
for dataset_index, meta in enumerate(tqdm.tqdm(all_metadata, desc="Find all tasks")):
task_index_to_aggr_task_index = {}
for task_index, task in meta.tasks.items():
if task not in aggr_meta.task_to_task_index:
# add the task to aggr tasks mappings
aggr_meta.tasks[aggr_task_index] = task
aggr_meta.task_to_task_index[task] = aggr_task_index
aggr_task_index += 1
# add task_index anyway
task_index_to_aggr_task_index[task_index] = aggr_meta.task_to_task_index[task]
datasets_task_index_to_aggr_task_index[dataset_index] = task_index_to_aggr_task_index
logging.info("Prepare copy data and videos")
datasets_ep_idx_to_aggr_ep_idx = {}
datasets_aggr_episode_index_shift = {}
aggr_episode_index_shift = 0
for dataset_index, meta in enumerate(tqdm.tqdm(all_metadata, desc="Prepare copy data and videos")):
ep_idx_to_aggr_ep_idx = {}
for episode_index in range(meta.total_episodes):
aggr_episode_index = episode_index + aggr_episode_index_shift
ep_idx_to_aggr_ep_idx[episode_index] = aggr_episode_index
datasets_ep_idx_to_aggr_ep_idx[dataset_index] = ep_idx_to_aggr_ep_idx
datasets_aggr_episode_index_shift[dataset_index] = aggr_episode_index_shift
# populate episodes
for episode_index, episode_dict in meta.episodes.items():
aggr_episode_index = episode_index + aggr_episode_index_shift
episode_dict["episode_index"] = aggr_episode_index
aggr_meta.episodes[aggr_episode_index] = episode_dict
# populate episodes_stats
for episode_index, episode_stats in meta.episodes_stats.items():
aggr_episode_index = episode_index + aggr_episode_index_shift
aggr_meta.episodes_stats[aggr_episode_index] = episode_stats
# populate info
aggr_meta.info["total_episodes"] += meta.total_episodes
aggr_meta.info["total_frames"] += meta.total_frames
aggr_meta.info["total_videos"] += len(aggr_meta.video_keys) * meta.total_episodes
aggr_episode_index_shift += meta.total_episodes
logging.info("Write meta data")
aggr_meta.info["total_tasks"] = len(aggr_meta.tasks)
aggr_meta.info["total_chunks"] = aggr_meta.get_episode_chunk(aggr_episode_index_shift - 1)
aggr_meta.info["splits"] = {"train": f"0:{aggr_meta.info['total_episodes']}"}
# create a new episodes jsonl with updated episode_index using write_episode
for episode_dict in tqdm.tqdm(aggr_meta.episodes.values(), desc="Write episodes"):
write_episode(episode_dict, aggr_meta.root)
# create a new episode_stats jsonl with updated episode_index using write_episode_stats
for episode_index, episode_stats in tqdm.tqdm(
aggr_meta.episodes_stats.items(), desc="Write episodes stats"
):
legacy_write_episode_stats(episode_index, episode_stats, aggr_meta.root)
# create a new task jsonl with updated episode_index using write_task
for task_index, task in tqdm.tqdm(aggr_meta.tasks.items(), desc="Write tasks"):
legacy_write_task(task_index, task, aggr_meta.root)
write_info(aggr_meta.info, aggr_meta.root)
self.datasets_task_index_to_aggr_task_index = datasets_task_index_to_aggr_task_index
self.datasets_ep_idx_to_aggr_ep_idx = datasets_ep_idx_to_aggr_ep_idx
self.datasets_aggr_episode_index_shift = datasets_aggr_episode_index_shift
logging.info("Meta data done writing!")
def run(self, data=None, rank: int = 0, world_size: int = 1): def run(self, data=None, rank: int = 0, world_size: int = 1):
import logging
import shutil
import pandas as pd
from lerobot.datasets.aggregate import get_update_episode_and_task_func
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.utils.utils import init_logging
init_logging() init_logging()
aggr_meta = LeRobotDatasetMetadata(self.aggr_repo_id) # Since aggregate_datasets already handles parallel processing internally,
all_metadata = [LeRobotDatasetMetadata(repo_id) for repo_id in self.repo_ids] # we only need one worker to run the entire aggregation
if rank == 0:
if world_size != len(all_metadata): logging.info(f"Starting aggregation of {len(self.repo_ids)} datasets into {self.aggr_repo_id}")
raise ValueError() aggregate_datasets(self.repo_ids, self.aggr_repo_id)
logging.info("Aggregation complete!")
dataset_index = rank else:
meta = all_metadata[dataset_index] logging.info(f"Worker {rank} skipping - only worker 0 performs aggregation")
aggr_episode_index_shift = self.datasets_aggr_episode_index_shift[dataset_index]
logging.info("Copy data")
for episode_index in range(meta.total_episodes):
aggr_episode_index = self.datasets_ep_idx_to_aggr_ep_idx[dataset_index][episode_index]
data_path = meta.root / meta.get_data_file_path(episode_index)
aggr_data_path = aggr_meta.root / aggr_meta.get_data_file_path(aggr_episode_index)
# update episode_index and task_index
df = pd.read_parquet(data_path)
update_row_func = get_update_episode_and_task_func(
aggr_episode_index_shift, self.datasets_task_index_to_aggr_task_index[dataset_index]
)
df = df.apply(update_row_func, axis=1)
aggr_data_path.parent.mkdir(parents=True, exist_ok=True)
df.to_parquet(aggr_data_path)
logging.info("Copy videos")
for episode_index in range(meta.total_episodes):
aggr_episode_index = episode_index + aggr_episode_index_shift
for vid_key in meta.video_keys:
video_path = meta.root / meta.get_video_file_path(episode_index, vid_key)
aggr_video_path = aggr_meta.root / aggr_meta.get_video_file_path(aggr_episode_index, vid_key)
aggr_video_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(video_path, aggr_video_path)
# copy_command = f"cp {video_path} {aggr_video_path} &"
# subprocess.Popen(copy_command, shell=True)
logging.info("Done!")
def make_aggregate_executor( def make_aggregate_executor(
@@ -207,11 +61,12 @@ def make_aggregate_executor(
} }
if slurm: if slurm:
# For aggregation, we only need 1 task since aggregate_datasets handles everything
kwargs.update( kwargs.update(
{ {
"job_name": job_name, "job_name": job_name,
"tasks": DROID_SHARDS, "tasks": 1, # Only need 1 task for aggregation
"workers": workers, "workers": 1, # Only need 1 worker
"time": "08:00:00", "time": "08:00:00",
"partition": partition, "partition": partition,
"cpus_per_task": cpus_per_task, "cpus_per_task": cpus_per_task,
@@ -222,7 +77,7 @@ def make_aggregate_executor(
else: else:
kwargs.update( kwargs.update(
{ {
"tasks": DROID_SHARDS, "tasks": 1,
"workers": 1, "workers": 1,
} }
) )
@@ -237,7 +92,7 @@ def main():
parser.add_argument( parser.add_argument(
"--repo-id", "--repo-id",
type=str, type=str,
help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset, required when push-to-hub is True.", help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset, required when push-to-hub is True.",
) )
parser.add_argument( parser.add_argument(
"--logs-dir", "--logs-dir",
@@ -259,8 +114,8 @@ def main():
parser.add_argument( parser.add_argument(
"--workers", "--workers",
type=int, type=int,
default=2048, default=1, # Changed default to 1 since aggregation doesn't need multiple workers
help="Number of slurm workers. It should be less than the maximum number of shards.", help="Number of slurm workers. For aggregation, this should be 1.",
) )
parser.add_argument( parser.add_argument(
"--partition", "--partition",