mirror of
https://github.com/Tavish9/any4lerobot.git
synced 2026-05-23 17:59:41 +00:00
⛏️ Fix libero2lerobot (#44)
* update debug logic * NIT: change class name * Update libero2lerobot/libero_h5.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
+73
-59
@@ -203,8 +203,16 @@ class AggregateDatasets(PipelineStep):
|
|||||||
aggr_video_path.parent.mkdir(parents=True, exist_ok=True)
|
aggr_video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
shutil.copy(video_path, aggr_video_path)
|
shutil.copy(video_path, aggr_video_path)
|
||||||
|
|
||||||
logger.info("Remove original data")
|
|
||||||
shutil.rmtree(meta.root)
|
class DeleteTempData(PipelineStep):
|
||||||
|
def __init__(self, temp_dirs: list[Path]):
|
||||||
|
self.temp_dirs = temp_dirs
|
||||||
|
|
||||||
|
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||||
|
logger = setup_logger()
|
||||||
|
|
||||||
|
logger.info(f"Delete temp data {self.temp_dirs[rank]}")
|
||||||
|
shutil.rmtree(self.temp_dirs[rank])
|
||||||
|
|
||||||
|
|
||||||
def main(
|
def main(
|
||||||
@@ -214,19 +222,22 @@ def main(
|
|||||||
cpus_per_task: int,
|
cpus_per_task: int,
|
||||||
tasks_per_job: int,
|
tasks_per_job: int,
|
||||||
workers: int,
|
workers: int,
|
||||||
resume_from_save: Path,
|
resume_from_save: Path = None,
|
||||||
resume_from_aggregate: Path,
|
resume_from_aggregate: Path = None,
|
||||||
debug: bool = False,
|
debug: bool = False,
|
||||||
repo_id: str = None,
|
repo_id: str = None,
|
||||||
push_to_hub: bool = False,
|
push_to_hub: bool = False,
|
||||||
):
|
):
|
||||||
tasks = []
|
tasks = []
|
||||||
pattern = re.compile(r"_SCENE\d+_(.*?)_demo\.hdf5")
|
pattern1 = re.compile(r"_SCENE\d+_(.*?)_demo\.hdf5")
|
||||||
|
pattern2 = re.compile(r"(.*?)_demo\.hdf5")
|
||||||
for src_path in src_paths:
|
for src_path in src_paths:
|
||||||
for input_h5 in src_path.glob("*.hdf5"):
|
for input_h5 in src_path.glob("*.hdf5"):
|
||||||
match = pattern.search(input_h5.name)
|
match = pattern1.search(input_h5.name)
|
||||||
if match is None:
|
if match is None:
|
||||||
continue
|
match = pattern2.search(input_h5.name)
|
||||||
|
if match is None:
|
||||||
|
continue
|
||||||
tasks.append(
|
tasks.append(
|
||||||
(
|
(
|
||||||
input_h5,
|
input_h5,
|
||||||
@@ -238,61 +249,64 @@ def main(
|
|||||||
aggregate_output_path = output_path / ("_".join([src_path.name for src_path in src_paths]) + "_aggregated_lerobot")
|
aggregate_output_path = output_path / ("_".join([src_path.name for src_path in src_paths]) + "_aggregated_lerobot")
|
||||||
else:
|
else:
|
||||||
aggregate_output_path = output_path / f"{src_paths[0].name}_lerobot"
|
aggregate_output_path = output_path / f"{src_paths[0].name}_lerobot"
|
||||||
|
|
||||||
if debug:
|
if debug:
|
||||||
SaveLerobotDataset([tasks[0]]).run()
|
executor = "local"
|
||||||
else:
|
workers = 1
|
||||||
save_config = {
|
tasks = tasks[:2]
|
||||||
"tasks": len(tasks),
|
push_to_hub = False
|
||||||
"workers": workers,
|
|
||||||
"logging_dir": resume_from_save,
|
|
||||||
}
|
|
||||||
aggregate_config = {
|
|
||||||
"tasks": len(tasks),
|
|
||||||
"workers": workers,
|
|
||||||
"logging_dir": resume_from_aggregate,
|
|
||||||
}
|
|
||||||
|
|
||||||
match executor:
|
match executor:
|
||||||
case "local":
|
case "local":
|
||||||
workers = os.cpu_count() // cpus_per_task if workers == -1 else workers
|
workers = os.cpu_count() // cpus_per_task if workers == -1 else workers
|
||||||
save_config["workers"] = workers
|
executor = LocalPipelineExecutor
|
||||||
aggregate_config["workers"] = workers
|
case "ray":
|
||||||
executor = LocalPipelineExecutor
|
runtime_env = RuntimeEnv(
|
||||||
case "ray":
|
env_vars={
|
||||||
runtime_env = RuntimeEnv(
|
"HDF5_USE_FILE_LOCKING": "FALSE",
|
||||||
env_vars={
|
"HF_DATASETS_DISABLE_PROGRESS_BARS": "TRUE",
|
||||||
"HDF5_USE_FILE_LOCKING": "FALSE",
|
"SVT_LOG": "1",
|
||||||
"HF_DATASETS_DISABLE_PROGRESS_BARS": "TRUE",
|
},
|
||||||
"SVT_LOG": "1",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
ray.init(runtime_env=runtime_env)
|
|
||||||
save_config.update({"cpus_per_task": cpus_per_task, "tasks_per_job": tasks_per_job})
|
|
||||||
aggregate_config.update({"cpus_per_task": cpus_per_task, "tasks_per_job": tasks_per_job})
|
|
||||||
executor = RayPipelineExecutor
|
|
||||||
case _:
|
|
||||||
raise ValueError(f"Executor {executor} not supported")
|
|
||||||
|
|
||||||
executor(pipeline=[SaveLerobotDataset(tasks)], **save_config).run()
|
|
||||||
executor(pipeline=[AggregateDatasets([task[1] for task in tasks], aggregate_output_path)], **aggregate_config).run()
|
|
||||||
|
|
||||||
for task in tasks:
|
|
||||||
shutil.rmtree(task[1].parent, ignore_errors=True)
|
|
||||||
|
|
||||||
if push_to_hub:
|
|
||||||
assert repo_id is not None
|
|
||||||
tags = ["LeRobot", "libero", "franka"]
|
|
||||||
tags.extend([src_path.name for src_path in src_paths])
|
|
||||||
LeRobotDataset(
|
|
||||||
repo_id=repo_id,
|
|
||||||
root=aggregate_output_path,
|
|
||||||
).push_to_hub(
|
|
||||||
tags=tags,
|
|
||||||
private=False,
|
|
||||||
push_videos=True,
|
|
||||||
license="apache-2.0",
|
|
||||||
upload_large_folder=False,
|
|
||||||
)
|
)
|
||||||
|
ray.init(runtime_env=runtime_env)
|
||||||
|
executor = RayPipelineExecutor
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Executor {executor} not supported")
|
||||||
|
|
||||||
|
executor_config = {
|
||||||
|
"tasks": len(tasks),
|
||||||
|
"workers": workers,
|
||||||
|
**({"cpus_per_task": cpus_per_task, "tasks_per_job": tasks_per_job} if executor is RayPipelineExecutor else {}),
|
||||||
|
}
|
||||||
|
|
||||||
|
executor(pipeline=[SaveLerobotDataset(tasks)], **executor_config, logging_dir=resume_from_save).run()
|
||||||
|
executor(
|
||||||
|
pipeline=[DeleteTempData([task[1] for task in tasks])],
|
||||||
|
**executor_config,
|
||||||
|
depends=executor(
|
||||||
|
pipeline=[AggregateDatasets([task[1] for task in tasks], aggregate_output_path)],
|
||||||
|
**executor_config,
|
||||||
|
logging_dir=resume_from_aggregate,
|
||||||
|
),
|
||||||
|
).run()
|
||||||
|
|
||||||
|
for task in tasks:
|
||||||
|
shutil.rmtree(task[1].parent, ignore_errors=True)
|
||||||
|
|
||||||
|
if push_to_hub:
|
||||||
|
assert repo_id is not None
|
||||||
|
tags = ["LeRobot", "libero", "franka"]
|
||||||
|
tags.extend([src_path.name for src_path in src_paths])
|
||||||
|
LeRobotDataset(
|
||||||
|
repo_id=repo_id,
|
||||||
|
root=aggregate_output_path,
|
||||||
|
).push_to_hub(
|
||||||
|
tags=tags,
|
||||||
|
private=False,
|
||||||
|
push_videos=True,
|
||||||
|
license="apache-2.0",
|
||||||
|
upload_large_folder=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user