import argparse import os import re import shutil from pathlib import Path import pandas as pd import ray from datatrove.executor import LocalPipelineExecutor, RayPipelineExecutor from datatrove.pipeline.base import PipelineStep from lerobot.datasets import LeRobotDataset, LeRobotDatasetMetadata from lerobot.datasets.aggregate import ( aggregate_data, aggregate_metadata, aggregate_stats, aggregate_videos, validate_all_metadata, ) from lerobot.datasets.io_utils import write_info, write_stats, write_tasks from lerobot.datasets.utils import ( DEFAULT_CHUNK_SIZE, DEFAULT_DATA_FILE_SIZE_IN_MB, DEFAULT_VIDEO_FILE_SIZE_IN_MB, ) from libero_utils.config import LIBERO_FEATURES from libero_utils.libero_utils import load_local_episodes from ray.runtime_env import RuntimeEnv from tqdm import tqdm def setup_logger(): import sys from datatrove.utils.logging import logger logger.remove() logger.add(sys.stdout, level="INFO", colorize=True) return logger class SaveLerobotDataset(PipelineStep): name = "Save Temp LerobotDataset" type = "libero2lerobot" def __init__(self, tasks: list[tuple[Path, Path, str]]): super().__init__() self.tasks = tasks def run(self, data=None, rank: int = 0, world_size: int = 1): logger = setup_logger() input_h5, output_path, task_instruction = self.tasks[rank] if output_path.exists(): shutil.rmtree(output_path) dataset = LeRobotDataset.create( repo_id=f"{input_h5.parent.name}/{input_h5.name}", root=output_path, fps=20, robot_type="franka", features=LIBERO_FEATURES, ) logger.info(f"start processing for {input_h5}, saving to {output_path}") raw_dataset = load_local_episodes(input_h5) for episode_index, episode_data in enumerate(raw_dataset): with self.track_time("saving episode"): for frame_data in episode_data: frame_data["task"] = task_instruction dataset.add_frame(frame_data) dataset.save_episode() logger.info(f"process done for {dataset.repo_id}, episode {episode_index}, len {len(episode_data)}") def create_aggr_dataset(raw_dirs: list[Path], aggregated_dir: Path): logger = setup_logger() all_metadata = [LeRobotDatasetMetadata("", root=raw_dir) for raw_dir in raw_dirs] fps, robot_type, features = validate_all_metadata(all_metadata) if aggregated_dir.exists(): shutil.rmtree(aggregated_dir) aggr_meta = LeRobotDatasetMetadata.create( repo_id=f"{aggregated_dir.parent.name}/{aggregated_dir.name}", root=aggregated_dir, fps=fps, robot_type=robot_type, features=features, ) video_keys = [key for key in features if features[key]["dtype"] == "video"] unique_tasks = pd.concat([m.tasks for m in all_metadata]).index.unique() aggr_meta.tasks = pd.DataFrame({"task_index": range(len(unique_tasks))}, index=unique_tasks) meta_idx = {"chunk": 0, "file": 0} data_idx = {"chunk": 0, "file": 0} videos_idx = {key: {"chunk": 0, "file": 0, "latest_duration": 0, "episode_duration": 0} for key in video_keys} aggr_meta.episodes = {} for src_meta in tqdm(all_metadata, desc="Copy data and videos"): videos_idx = aggregate_videos( src_meta, aggr_meta, videos_idx, DEFAULT_VIDEO_FILE_SIZE_IN_MB, DEFAULT_CHUNK_SIZE ) data_idx = aggregate_data(src_meta, aggr_meta, data_idx, DEFAULT_DATA_FILE_SIZE_IN_MB, DEFAULT_CHUNK_SIZE) meta_idx = aggregate_metadata(src_meta, aggr_meta, meta_idx, data_idx, videos_idx) aggr_meta.info["total_episodes"] += src_meta.total_episodes aggr_meta.info["total_frames"] += src_meta.total_frames logger.info("write tasks") write_tasks(aggr_meta.tasks, aggr_meta.root) logger.info("write info") aggr_meta.info.update( { "total_tasks": len(aggr_meta.tasks), "total_episodes": sum(m.total_episodes for m in all_metadata), "total_frames": sum(m.total_frames for m in all_metadata), "splits": {"train": f"0:{sum(m.total_episodes for m in all_metadata)}"}, } ) write_info(aggr_meta.info, aggr_meta.root) logger.info("write stats") aggr_meta.stats = aggregate_stats([m.stats for m in all_metadata]) write_stats(aggr_meta.stats, aggr_meta.root) def delete_temp_data(temp_dirs: list[Path]): logger = setup_logger() logger.info("Delete temp data_dir") for temp_dir in temp_dirs: shutil.rmtree(temp_dir) def main( src_paths: list[Path], output_path: Path, executor: str, cpus_per_task: int, tasks_per_job: int, workers: int, resume_dir: Path = None, debug: bool = False, repo_id: str = None, push_to_hub: bool = False, ): tasks = [] pattern1 = re.compile(r"_SCENE\d+_(.*?)_demo\.hdf5") pattern2 = re.compile(r"(.*?)_demo\.hdf5") for src_path in src_paths: for input_h5 in src_path.glob("*.hdf5"): match = pattern1.search(input_h5.name) if match is None: match = pattern2.search(input_h5.name) if match is None: continue tasks.append( ( input_h5, (output_path / (src_path.name + "_temp") / input_h5.stem).resolve(), match.group(1).replace("_", " "), ) ) if len(src_paths) > 1: aggregate_output_path = output_path / ( "_".join([src_path.name for src_path in src_paths]) + "_aggregated_lerobot" ) else: aggregate_output_path = output_path / f"{src_paths[0].name}_lerobot" aggregate_output_path = aggregate_output_path.resolve() if debug: executor = "local" workers = 1 tasks = tasks[:2] push_to_hub = False match executor: case "local": workers = os.cpu_count() // cpus_per_task if workers == -1 else workers executor = LocalPipelineExecutor case "ray": runtime_env = RuntimeEnv( env_vars={ "HDF5_USE_FILE_LOCKING": "FALSE", "HF_DATASETS_DISABLE_PROGRESS_BARS": "TRUE", "SVT_LOG": "1", }, ) ray.init(runtime_env=runtime_env) executor = RayPipelineExecutor case _: raise ValueError(f"Executor {executor} not supported") executor_config = { "tasks": len(tasks), "workers": workers, **({"cpus_per_task": cpus_per_task, "tasks_per_job": tasks_per_job} if executor is RayPipelineExecutor else {}), } executor(pipeline=[SaveLerobotDataset(tasks)], **executor_config, logging_dir=resume_dir).run() create_aggr_dataset([task[1] for task in tasks], aggregate_output_path) delete_temp_data([task[1] for task in tasks]) for task in tasks: shutil.rmtree(task[1].parent, ignore_errors=True) if push_to_hub: assert repo_id is not None tags = ["LeRobot", "libero", "franka"] tags.extend([src_path.name for src_path in src_paths]) LeRobotDataset( repo_id=repo_id, root=aggregate_output_path, ).push_to_hub( tags=tags, private=False, push_videos=True, license="apache-2.0", upload_large_folder=False, ) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--src-paths", type=Path, nargs="+", required=True) parser.add_argument("--output-path", type=Path, required=True) parser.add_argument("--executor", type=str, choices=["local", "ray"], default="local") parser.add_argument("--cpus-per-task", type=int, default=1) parser.add_argument( "--tasks-per-job", type=int, default=1, help="number of concurrent tasks per job, only used for ray" ) parser.add_argument("--workers", type=int, default=-1, help="number of concurrent jobs to run") parser.add_argument("--resume-dir", type=Path, help="logs directory to resume") parser.add_argument("--debug", action="store_true") parser.add_argument("--repo-id", type=str, help="required when push-to-hub is True") parser.add_argument("--push-to-hub", action="store_true", help="upload to hub") args = parser.parse_args() main(**vars(args))