mirror of
https://github.com/Tavish9/any4lerobot.git
synced 2026-05-11 12:09:41 +00:00
⛏️ Fix libero2lerobot (#44)
* update debug logic * NIT: change class name * Update libero2lerobot/libero_h5.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
+73
-59
@@ -203,8 +203,16 @@ class AggregateDatasets(PipelineStep):
|
||||
aggr_video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy(video_path, aggr_video_path)
|
||||
|
||||
logger.info("Remove original data")
|
||||
shutil.rmtree(meta.root)
|
||||
|
||||
class DeleteTempData(PipelineStep):
|
||||
def __init__(self, temp_dirs: list[Path]):
|
||||
self.temp_dirs = temp_dirs
|
||||
|
||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||
logger = setup_logger()
|
||||
|
||||
logger.info(f"Delete temp data {self.temp_dirs[rank]}")
|
||||
shutil.rmtree(self.temp_dirs[rank])
|
||||
|
||||
|
||||
def main(
|
||||
@@ -214,19 +222,22 @@ def main(
|
||||
cpus_per_task: int,
|
||||
tasks_per_job: int,
|
||||
workers: int,
|
||||
resume_from_save: Path,
|
||||
resume_from_aggregate: Path,
|
||||
resume_from_save: Path = None,
|
||||
resume_from_aggregate: Path = None,
|
||||
debug: bool = False,
|
||||
repo_id: str = None,
|
||||
push_to_hub: bool = False,
|
||||
):
|
||||
tasks = []
|
||||
pattern = re.compile(r"_SCENE\d+_(.*?)_demo\.hdf5")
|
||||
pattern1 = re.compile(r"_SCENE\d+_(.*?)_demo\.hdf5")
|
||||
pattern2 = re.compile(r"(.*?)_demo\.hdf5")
|
||||
for src_path in src_paths:
|
||||
for input_h5 in src_path.glob("*.hdf5"):
|
||||
match = pattern.search(input_h5.name)
|
||||
match = pattern1.search(input_h5.name)
|
||||
if match is None:
|
||||
continue
|
||||
match = pattern2.search(input_h5.name)
|
||||
if match is None:
|
||||
continue
|
||||
tasks.append(
|
||||
(
|
||||
input_h5,
|
||||
@@ -238,61 +249,64 @@ def main(
|
||||
aggregate_output_path = output_path / ("_".join([src_path.name for src_path in src_paths]) + "_aggregated_lerobot")
|
||||
else:
|
||||
aggregate_output_path = output_path / f"{src_paths[0].name}_lerobot"
|
||||
|
||||
if debug:
|
||||
SaveLerobotDataset([tasks[0]]).run()
|
||||
else:
|
||||
save_config = {
|
||||
"tasks": len(tasks),
|
||||
"workers": workers,
|
||||
"logging_dir": resume_from_save,
|
||||
}
|
||||
aggregate_config = {
|
||||
"tasks": len(tasks),
|
||||
"workers": workers,
|
||||
"logging_dir": resume_from_aggregate,
|
||||
}
|
||||
executor = "local"
|
||||
workers = 1
|
||||
tasks = tasks[:2]
|
||||
push_to_hub = False
|
||||
|
||||
match executor:
|
||||
case "local":
|
||||
workers = os.cpu_count() // cpus_per_task if workers == -1 else workers
|
||||
save_config["workers"] = workers
|
||||
aggregate_config["workers"] = workers
|
||||
executor = LocalPipelineExecutor
|
||||
case "ray":
|
||||
runtime_env = RuntimeEnv(
|
||||
env_vars={
|
||||
"HDF5_USE_FILE_LOCKING": "FALSE",
|
||||
"HF_DATASETS_DISABLE_PROGRESS_BARS": "TRUE",
|
||||
"SVT_LOG": "1",
|
||||
},
|
||||
)
|
||||
ray.init(runtime_env=runtime_env)
|
||||
save_config.update({"cpus_per_task": cpus_per_task, "tasks_per_job": tasks_per_job})
|
||||
aggregate_config.update({"cpus_per_task": cpus_per_task, "tasks_per_job": tasks_per_job})
|
||||
executor = RayPipelineExecutor
|
||||
case _:
|
||||
raise ValueError(f"Executor {executor} not supported")
|
||||
|
||||
executor(pipeline=[SaveLerobotDataset(tasks)], **save_config).run()
|
||||
executor(pipeline=[AggregateDatasets([task[1] for task in tasks], aggregate_output_path)], **aggregate_config).run()
|
||||
|
||||
for task in tasks:
|
||||
shutil.rmtree(task[1].parent, ignore_errors=True)
|
||||
|
||||
if push_to_hub:
|
||||
assert repo_id is not None
|
||||
tags = ["LeRobot", "libero", "franka"]
|
||||
tags.extend([src_path.name for src_path in src_paths])
|
||||
LeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=aggregate_output_path,
|
||||
).push_to_hub(
|
||||
tags=tags,
|
||||
private=False,
|
||||
push_videos=True,
|
||||
license="apache-2.0",
|
||||
upload_large_folder=False,
|
||||
match executor:
|
||||
case "local":
|
||||
workers = os.cpu_count() // cpus_per_task if workers == -1 else workers
|
||||
executor = LocalPipelineExecutor
|
||||
case "ray":
|
||||
runtime_env = RuntimeEnv(
|
||||
env_vars={
|
||||
"HDF5_USE_FILE_LOCKING": "FALSE",
|
||||
"HF_DATASETS_DISABLE_PROGRESS_BARS": "TRUE",
|
||||
"SVT_LOG": "1",
|
||||
},
|
||||
)
|
||||
ray.init(runtime_env=runtime_env)
|
||||
executor = RayPipelineExecutor
|
||||
case _:
|
||||
raise ValueError(f"Executor {executor} not supported")
|
||||
|
||||
executor_config = {
|
||||
"tasks": len(tasks),
|
||||
"workers": workers,
|
||||
**({"cpus_per_task": cpus_per_task, "tasks_per_job": tasks_per_job} if executor is RayPipelineExecutor else {}),
|
||||
}
|
||||
|
||||
executor(pipeline=[SaveLerobotDataset(tasks)], **executor_config, logging_dir=resume_from_save).run()
|
||||
executor(
|
||||
pipeline=[DeleteTempData([task[1] for task in tasks])],
|
||||
**executor_config,
|
||||
depends=executor(
|
||||
pipeline=[AggregateDatasets([task[1] for task in tasks], aggregate_output_path)],
|
||||
**executor_config,
|
||||
logging_dir=resume_from_aggregate,
|
||||
),
|
||||
).run()
|
||||
|
||||
for task in tasks:
|
||||
shutil.rmtree(task[1].parent, ignore_errors=True)
|
||||
|
||||
if push_to_hub:
|
||||
assert repo_id is not None
|
||||
tags = ["LeRobot", "libero", "franka"]
|
||||
tags.extend([src_path.name for src_path in src_paths])
|
||||
LeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=aggregate_output_path,
|
||||
).push_to_hub(
|
||||
tags=tags,
|
||||
private=False,
|
||||
push_videos=True,
|
||||
license="apache-2.0",
|
||||
upload_large_folder=False,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user