diff --git a/libero2lerobot/libero_h5.py b/libero2lerobot/libero_h5.py
index bbff34b..a918e39 100644
--- a/libero2lerobot/libero_h5.py
+++ b/libero2lerobot/libero_h5.py
@@ -203,8 +203,16 @@ class AggregateDatasets(PipelineStep):
             aggr_video_path.parent.mkdir(parents=True, exist_ok=True)
             shutil.copy(video_path, aggr_video_path)
 
-        logger.info("Remove original data")
-        shutil.rmtree(meta.root)
+
+class DeleteTempData(PipelineStep):
+    def __init__(self, temp_dirs: list[Path]):
+        self.temp_dirs = temp_dirs
+
+    def run(self, data=None, rank: int = 0, world_size: int = 1):
+        logger = setup_logger()
+
+        logger.info(f"Delete temp data {self.temp_dirs[rank]}")
+        shutil.rmtree(self.temp_dirs[rank])
 
 
 def main(
@@ -214,19 +222,22 @@ def main(
     cpus_per_task: int,
     tasks_per_job: int,
     workers: int,
-    resume_from_save: Path,
-    resume_from_aggregate: Path,
+    resume_from_save: Path = None,
+    resume_from_aggregate: Path = None,
     debug: bool = False,
     repo_id: str = None,
    push_to_hub: bool = False,
 ):
     tasks = []
-    pattern = re.compile(r"_SCENE\d+_(.*?)_demo\.hdf5")
+    pattern1 = re.compile(r"_SCENE\d+_(.*?)_demo\.hdf5")
+    pattern2 = re.compile(r"(.*?)_demo\.hdf5")
     for src_path in src_paths:
         for input_h5 in src_path.glob("*.hdf5"):
-            match = pattern.search(input_h5.name)
+            match = pattern1.search(input_h5.name)
             if match is None:
-                continue
+                match = pattern2.search(input_h5.name)
+                if match is None:
+                    continue
             tasks.append(
                 (
                     input_h5,
@@ -238,61 +249,64 @@ def main(
         aggregate_output_path = output_path / ("_".join([src_path.name for src_path in src_paths]) + "_aggregated_lerobot")
     else:
         aggregate_output_path = output_path / f"{src_paths[0].name}_lerobot"
+
     if debug:
-        SaveLerobotDataset([tasks[0]]).run()
-    else:
-        save_config = {
-            "tasks": len(tasks),
-            "workers": workers,
-            "logging_dir": resume_from_save,
-        }
-        aggregate_config = {
-            "tasks": len(tasks),
-            "workers": workers,
-            "logging_dir": resume_from_aggregate,
-        }
+        executor = "local"
+        workers = 1
+        tasks = tasks[:2]
+        push_to_hub = False
 
-        match executor:
-            case "local":
-                workers = os.cpu_count() // cpus_per_task if workers == -1 else workers
-                save_config["workers"] = workers
-                aggregate_config["workers"] = workers
-                executor = LocalPipelineExecutor
-            case "ray":
-                runtime_env = RuntimeEnv(
-                    env_vars={
-                        "HDF5_USE_FILE_LOCKING": "FALSE",
-                        "HF_DATASETS_DISABLE_PROGRESS_BARS": "TRUE",
-                        "SVT_LOG": "1",
-                    },
-                )
-                ray.init(runtime_env=runtime_env)
-                save_config.update({"cpus_per_task": cpus_per_task, "tasks_per_job": tasks_per_job})
-                aggregate_config.update({"cpus_per_task": cpus_per_task, "tasks_per_job": tasks_per_job})
-                executor = RayPipelineExecutor
-            case _:
-                raise ValueError(f"Executor {executor} not supported")
-
-        executor(pipeline=[SaveLerobotDataset(tasks)], **save_config).run()
-        executor(pipeline=[AggregateDatasets([task[1] for task in tasks], aggregate_output_path)], **aggregate_config).run()
-
-        for task in tasks:
-            shutil.rmtree(task[1].parent, ignore_errors=True)
-
-        if push_to_hub:
-            assert repo_id is not None
-            tags = ["LeRobot", "libero", "franka"]
-            tags.extend([src_path.name for src_path in src_paths])
-            LeRobotDataset(
-                repo_id=repo_id,
-                root=aggregate_output_path,
-            ).push_to_hub(
-                tags=tags,
-                private=False,
-                push_videos=True,
-                license="apache-2.0",
-                upload_large_folder=False,
-            )
+    match executor:
+        case "local":
+            workers = os.cpu_count() // cpus_per_task if workers == -1 else workers
+            executor = LocalPipelineExecutor
+        case "ray":
+            runtime_env = RuntimeEnv(
+                env_vars={
+                    "HDF5_USE_FILE_LOCKING": "FALSE",
+                    "HF_DATASETS_DISABLE_PROGRESS_BARS": "TRUE",
+                    "SVT_LOG": "1",
+                },
+            )
+            ray.init(runtime_env=runtime_env)
+            executor = RayPipelineExecutor
+        case _:
+            raise ValueError(f"Executor {executor} not supported")
+
+    executor_config = {
+        "tasks": len(tasks),
+        "workers": workers,
+        **({"cpus_per_task": cpus_per_task, "tasks_per_job": tasks_per_job} if executor is RayPipelineExecutor else {}),
+    }
+
+    executor(pipeline=[SaveLerobotDataset(tasks)], **executor_config, logging_dir=resume_from_save).run()
+    executor(
+        pipeline=[DeleteTempData([task[1] for task in tasks])],
+        **executor_config,
+        depends=executor(
+            pipeline=[AggregateDatasets([task[1] for task in tasks], aggregate_output_path)],
+            **executor_config,
+            logging_dir=resume_from_aggregate,
+        ),
+    ).run()
+
+    for task in tasks:
+        shutil.rmtree(task[1].parent, ignore_errors=True)
+
+    if push_to_hub:
+        assert repo_id is not None
+        tags = ["LeRobot", "libero", "franka"]
+        tags.extend([src_path.name for src_path in src_paths])
+        LeRobotDataset(
+            repo_id=repo_id,
+            root=aggregate_output_path,
+        ).push_to_hub(
+            tags=tags,
+            private=False,
+            push_videos=True,
+            license="apache-2.0",
+            upload_large_folder=False,
+        )
 
 
 if __name__ == "__main__":