⛏️ Fix libero2lerobot (#44)

* update debug logic

* NIT: change class name

* Update libero2lerobot/libero_h5.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Qizhi Chen
2025-06-28 11:33:51 +08:00
committed by GitHub
parent 4dc21b9b70
commit d7694a54cc
+73 -59
View File
@@ -203,8 +203,16 @@ class AggregateDatasets(PipelineStep):
aggr_video_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(video_path, aggr_video_path)
logger.info("Remove original data")
shutil.rmtree(meta.root)
class DeleteTempData(PipelineStep):
def __init__(self, temp_dirs: list[Path]):
self.temp_dirs = temp_dirs
def run(self, data=None, rank: int = 0, world_size: int = 1):
logger = setup_logger()
logger.info(f"Delete temp data {self.temp_dirs[rank]}")
shutil.rmtree(self.temp_dirs[rank])
def main(
@@ -214,19 +222,22 @@ def main(
cpus_per_task: int,
tasks_per_job: int,
workers: int,
resume_from_save: Path,
resume_from_aggregate: Path,
resume_from_save: Path = None,
resume_from_aggregate: Path = None,
debug: bool = False,
repo_id: str = None,
push_to_hub: bool = False,
):
tasks = []
pattern = re.compile(r"_SCENE\d+_(.*?)_demo\.hdf5")
pattern1 = re.compile(r"_SCENE\d+_(.*?)_demo\.hdf5")
pattern2 = re.compile(r"(.*?)_demo\.hdf5")
for src_path in src_paths:
for input_h5 in src_path.glob("*.hdf5"):
match = pattern.search(input_h5.name)
match = pattern1.search(input_h5.name)
if match is None:
continue
match = pattern2.search(input_h5.name)
if match is None:
continue
tasks.append(
(
input_h5,
@@ -238,61 +249,64 @@ def main(
aggregate_output_path = output_path / ("_".join([src_path.name for src_path in src_paths]) + "_aggregated_lerobot")
else:
aggregate_output_path = output_path / f"{src_paths[0].name}_lerobot"
if debug:
SaveLerobotDataset([tasks[0]]).run()
else:
save_config = {
"tasks": len(tasks),
"workers": workers,
"logging_dir": resume_from_save,
}
aggregate_config = {
"tasks": len(tasks),
"workers": workers,
"logging_dir": resume_from_aggregate,
}
executor = "local"
workers = 1
tasks = tasks[:2]
push_to_hub = False
match executor:
case "local":
workers = os.cpu_count() // cpus_per_task if workers == -1 else workers
save_config["workers"] = workers
aggregate_config["workers"] = workers
executor = LocalPipelineExecutor
case "ray":
runtime_env = RuntimeEnv(
env_vars={
"HDF5_USE_FILE_LOCKING": "FALSE",
"HF_DATASETS_DISABLE_PROGRESS_BARS": "TRUE",
"SVT_LOG": "1",
},
)
ray.init(runtime_env=runtime_env)
save_config.update({"cpus_per_task": cpus_per_task, "tasks_per_job": tasks_per_job})
aggregate_config.update({"cpus_per_task": cpus_per_task, "tasks_per_job": tasks_per_job})
executor = RayPipelineExecutor
case _:
raise ValueError(f"Executor {executor} not supported")
executor(pipeline=[SaveLerobotDataset(tasks)], **save_config).run()
executor(pipeline=[AggregateDatasets([task[1] for task in tasks], aggregate_output_path)], **aggregate_config).run()
for task in tasks:
shutil.rmtree(task[1].parent, ignore_errors=True)
if push_to_hub:
assert repo_id is not None
tags = ["LeRobot", "libero", "franka"]
tags.extend([src_path.name for src_path in src_paths])
LeRobotDataset(
repo_id=repo_id,
root=aggregate_output_path,
).push_to_hub(
tags=tags,
private=False,
push_videos=True,
license="apache-2.0",
upload_large_folder=False,
match executor:
case "local":
workers = os.cpu_count() // cpus_per_task if workers == -1 else workers
executor = LocalPipelineExecutor
case "ray":
runtime_env = RuntimeEnv(
env_vars={
"HDF5_USE_FILE_LOCKING": "FALSE",
"HF_DATASETS_DISABLE_PROGRESS_BARS": "TRUE",
"SVT_LOG": "1",
},
)
ray.init(runtime_env=runtime_env)
executor = RayPipelineExecutor
case _:
raise ValueError(f"Executor {executor} not supported")
executor_config = {
"tasks": len(tasks),
"workers": workers,
**({"cpus_per_task": cpus_per_task, "tasks_per_job": tasks_per_job} if executor is RayPipelineExecutor else {}),
}
executor(pipeline=[SaveLerobotDataset(tasks)], **executor_config, logging_dir=resume_from_save).run()
executor(
pipeline=[DeleteTempData([task[1] for task in tasks])],
**executor_config,
depends=executor(
pipeline=[AggregateDatasets([task[1] for task in tasks], aggregate_output_path)],
**executor_config,
logging_dir=resume_from_aggregate,
),
).run()
for task in tasks:
shutil.rmtree(task[1].parent, ignore_errors=True)
if push_to_hub:
assert repo_id is not None
tags = ["LeRobot", "libero", "franka"]
tags.extend([src_path.name for src_path in src_paths])
LeRobotDataset(
repo_id=repo_id,
root=aggregate_output_path,
).push_to_hub(
tags=tags,
private=False,
push_videos=True,
license="apache-2.0",
upload_large_folder=False,
)
if __name__ == "__main__":