Files
any4lerobot/libero2lerobot/libero_h5.py
T
2025-07-01 14:27:11 +08:00

337 lines
13 KiB
Python

import argparse
import os
import re
import shutil
from pathlib import Path
import pandas as pd
import ray
from datatrove.executor import LocalPipelineExecutor, RayPipelineExecutor
from datatrove.pipeline.base import PipelineStep
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.utils import (
write_episode,
write_episode_stats,
write_info,
write_task,
)
from libero_utils.config import LIBERO_FEATURES
from libero_utils.lerobot_utils import validate_all_metadata
from libero_utils.libero_utils import load_local_episodes
from ray.runtime_env import RuntimeEnv
from tqdm import tqdm
def setup_logger():
    """Reconfigure datatrove's shared logger to emit to stdout at INFO level.

    The imports are local to the function, so ``sys`` and the datatrove
    logger are resolved each time it is called.

    Returns:
        The ``logger`` object from ``datatrove.utils.logging`` after
        reconfiguration.
    """
    import sys

    from datatrove.utils.logging import logger as datatrove_logger

    # remove() is called without arguments before the stdout sink is added.
    datatrove_logger.remove()
    datatrove_logger.add(sys.stdout, colorize=True, level="INFO")
    return datatrove_logger
class SaveLerobotDataset(PipelineStep):
    """Pipeline step that converts one LIBERO HDF5 file into a LeRobot dataset.

    Each rank processes exactly one entry of ``tasks``.
    """

    name = "Save Temp LerobotDataset"
    type = "libero2lerobot"

    def __init__(self, tasks: list[tuple[Path, Path, str]]):
        """Remember the work list.

        Args:
            tasks: One ``(input_h5, output_path, task_instruction)`` tuple
                per rank.
        """
        super().__init__()
        self.tasks = tasks

    def run(self, data=None, rank: int = 0, world_size: int = 1):
        """Convert the HDF5 file assigned to ``rank``.

        Args:
            data: Unused.
            rank: Index into ``self.tasks`` selecting the file to convert.
            world_size: Unused.
        """
        log = setup_logger()
        input_h5, output_path, instruction = self.tasks[rank]

        # Start from a clean directory: any previous output is discarded.
        if output_path.exists():
            shutil.rmtree(output_path)

        dataset = LeRobotDataset.create(
            repo_id=f"{input_h5.parent.name}/{input_h5.name}",
            root=output_path,
            fps=20,
            robot_type="franka",
            features=LIBERO_FEATURES,
        )
        log.info(f"start processing for {input_h5}, saving to {output_path}")

        for episode_index, episode_data in enumerate(load_local_episodes(input_h5)):
            with self.track_time("saving episode"):
                # Every frame of the episode is tagged with the same instruction.
                for frame in episode_data:
                    dataset.add_frame(frame, task=instruction)
                dataset.save_episode()
                log.info(f"process done for {dataset.repo_id}, episode {episode_index}, len {len(episode_data)}")
class AggregateDatasets(PipelineStep):
    """Pipeline step that merges several LeRobot datasets into a single one.

    The aggregated metadata (tasks, episodes, episode stats, info) is written
    once, at construction time, by ``create_aggr_dataset``. ``run`` then copies
    the parquet data and video files of the dataset selected by ``rank`` into
    the aggregated directory, re-indexing them on the way.
    """

    name = "Aggregate Datasets"
    type = "libero2lerobot"

    def __init__(self, raw_dirs: list[Path], aggregated_dir: Path):
        """Store the paths and immediately build the aggregated metadata.

        Args:
            raw_dirs: Root directories of the datasets to merge; the position
                in this list is the dataset index used by ``run``.
            aggregated_dir: Root directory of the merged dataset. It is deleted
                and recreated if it already exists.
        """
        super().__init__()
        self.raw_dirs = raw_dirs
        self.aggregated_dir = aggregated_dir
        # NOTE: constructing the step has side effects on disk.
        self.create_aggr_dataset()

    def create_aggr_dataset(self):
        """Create the aggregated dataset's metadata and the re-indexing tables.

        Populates ``self.datasets_task_index_to_aggr_task_index``,
        ``self.datasets_ep_idx_to_aggr_ep_idx``,
        ``self.datasets_aggr_episode_index_shift`` and
        ``self.datasets_aggr_index_shift``, each keyed by dataset index.
        """
        logger = setup_logger()

        all_metadata = [LeRobotDatasetMetadata("", root=raw_dir) for raw_dir in self.raw_dirs]
        fps, robot_type, features = validate_all_metadata(all_metadata)

        if self.aggregated_dir.exists():
            shutil.rmtree(self.aggregated_dir)

        aggr_meta = LeRobotDatasetMetadata.create(
            repo_id=f"{self.aggregated_dir.parent.name}/{self.aggregated_dir.name}",
            root=self.aggregated_dir,
            fps=fps,
            robot_type=robot_type,
            features=features,
        )

        # Map every dataset-local task index to a deduplicated global one.
        datasets_task_index_to_aggr_task_index = {}
        aggr_task_index = 0
        for dataset_index, meta in enumerate(tqdm(all_metadata, desc="Aggregate tasks index")):
            task_index_to_aggr_task_index = {}
            for task_index, task in meta.tasks.items():
                if task not in aggr_meta.task_to_task_index:
                    # add the task to aggr tasks mappings
                    aggr_meta.tasks[aggr_task_index] = task
                    aggr_meta.task_to_task_index[task] = aggr_task_index
                    aggr_task_index += 1
                # record the mapping whether or not the task was new
                task_index_to_aggr_task_index[task_index] = aggr_meta.task_to_task_index[task]
            datasets_task_index_to_aggr_task_index[dataset_index] = task_index_to_aggr_task_index

        # Datasets are concatenated in order: each one is shifted by the number
        # of episodes / frames accumulated from the datasets before it.
        datasets_ep_idx_to_aggr_ep_idx = {}
        datasets_aggr_episode_index_shift = {}
        datasets_aggr_index_shift = {}
        aggr_episode_index_shift = 0
        for dataset_index, meta in enumerate(tqdm(all_metadata, desc="Aggregate episodes and global index")):
            ep_idx_to_aggr_ep_idx = {
                episode_index: episode_index + aggr_episode_index_shift
                for episode_index in range(meta.total_episodes)
            }
            datasets_ep_idx_to_aggr_ep_idx[dataset_index] = ep_idx_to_aggr_ep_idx
            datasets_aggr_episode_index_shift[dataset_index] = aggr_episode_index_shift
            # Must be read before this dataset's frames are added to the total below.
            datasets_aggr_index_shift[dataset_index] = aggr_meta.total_frames

            # populate episodes
            for episode_index, episode_dict in meta.episodes.items():
                aggr_episode_index = episode_index + aggr_episode_index_shift
                episode_dict["episode_index"] = aggr_episode_index
                aggr_meta.episodes[aggr_episode_index] = episode_dict

            # populate episodes_stats
            for episode_index, episode_stats in meta.episodes_stats.items():
                aggr_episode_index = episode_index + aggr_episode_index_shift
                aggr_meta.episodes_stats[aggr_episode_index] = episode_stats

            # populate info
            aggr_meta.info["total_episodes"] += meta.total_episodes
            aggr_meta.info["total_frames"] += meta.total_frames
            aggr_meta.info["total_videos"] += len(aggr_meta.video_keys) * meta.total_episodes

            aggr_episode_index_shift += meta.total_episodes

        logger.info("Write meta data")
        aggr_meta.info["total_tasks"] = len(aggr_meta.tasks)
        # get_episode_chunk() is called with the last episode index, so its
        # result is used here as the index of the last chunk; total_chunks is
        # a count, hence the + 1. An empty aggregate has no chunk at all.
        # NOTE(review): count semantics match how LeRobot grows total_chunks
        # when saving episodes — confirm against the installed version.
        if aggr_episode_index_shift > 0:
            aggr_meta.info["total_chunks"] = aggr_meta.get_episode_chunk(aggr_episode_index_shift - 1) + 1
        else:
            aggr_meta.info["total_chunks"] = 0
        aggr_meta.info["splits"] = {"train": f"0:{aggr_meta.info['total_episodes']}"}

        # create a new episodes jsonl with updated episode_index using write_episode
        for episode_dict in tqdm(aggr_meta.episodes.values(), desc="Write episodes info"):
            write_episode(episode_dict, aggr_meta.root)

        # create a new episode_stats jsonl with updated episode_index using write_episode_stats
        for episode_index, episode_stats in tqdm(aggr_meta.episodes_stats.items(), desc="Write episodes stats info"):
            write_episode_stats(episode_index, episode_stats, aggr_meta.root)

        # create a new task jsonl with updated episode_index using write_task
        for task_index, task in tqdm(aggr_meta.tasks.items(), desc="Write tasks info"):
            write_task(task_index, task, aggr_meta.root)

        write_info(aggr_meta.info, aggr_meta.root)

        self.datasets_task_index_to_aggr_task_index = datasets_task_index_to_aggr_task_index
        self.datasets_ep_idx_to_aggr_ep_idx = datasets_ep_idx_to_aggr_ep_idx
        self.datasets_aggr_episode_index_shift = datasets_aggr_episode_index_shift
        self.datasets_aggr_index_shift = datasets_aggr_index_shift

        logger.info("Meta data done writing")

    def run(self, data=None, rank: int = 0, world_size: int = 1):
        """Copy one source dataset's data and videos into the aggregate.

        Args:
            data: Unused.
            rank: Index of the source dataset in ``self.raw_dirs``.
            world_size: Unused.
        """
        logger = setup_logger()

        dataset_index = rank
        aggr_meta = LeRobotDatasetMetadata("", root=self.aggregated_dir)
        meta = LeRobotDatasetMetadata("", root=self.raw_dirs[dataset_index])

        aggr_episode_index_shift = self.datasets_aggr_episode_index_shift[dataset_index]
        aggr_index_shift = self.datasets_aggr_index_shift[dataset_index]
        task_index_to_aggr_task_index = self.datasets_task_index_to_aggr_task_index[dataset_index]

        with self.track_time("aggregating dataset"):
            logger.info("Copy data")
            for episode_index in range(meta.total_episodes):
                aggr_episode_index = self.datasets_ep_idx_to_aggr_ep_idx[dataset_index][episode_index]
                data_path = meta.root / meta.get_data_file_path(episode_index)
                aggr_data_path = aggr_meta.root / aggr_meta.get_data_file_path(aggr_episode_index)
                aggr_data_path.parent.mkdir(parents=True, exist_ok=True)

                # update index, episode_index and task_index
                df = pd.read_parquet(data_path)
                df["index"] += aggr_index_shift
                df["episode_index"] += aggr_episode_index_shift
                df["task_index"] = df["task_index"].map(task_index_to_aggr_task_index)
                df.to_parquet(aggr_data_path)

            logger.info("Copy videos")
            for episode_index in range(meta.total_episodes):
                aggr_episode_index = episode_index + aggr_episode_index_shift
                for vid_key in meta.video_keys:
                    video_path = meta.root / meta.get_video_file_path(episode_index, vid_key)
                    aggr_video_path = aggr_meta.root / aggr_meta.get_video_file_path(aggr_episode_index, vid_key)
                    aggr_video_path.parent.mkdir(parents=True, exist_ok=True)
                    shutil.copy(video_path, aggr_video_path)
class DeleteTempData(PipelineStep):
    """Pipeline step that removes one temporary directory per rank."""

    name = "Delete Temp Data"
    type = "libero2lerobot"

    def __init__(self, temp_dirs: list[Path]):
        """Remember the directories to delete.

        Args:
            temp_dirs: Directories to remove; ``run`` deletes the entry at
                index ``rank``.
        """
        super().__init__()
        self.temp_dirs = temp_dirs

    def run(self, data=None, rank: int = 0, world_size: int = 1):
        """Recursively delete the directory assigned to ``rank``.

        Args:
            data: Unused.
            rank: Index into ``self.temp_dirs``.
            world_size: Unused.
        """
        log = setup_logger()
        doomed_dir = self.temp_dirs[rank]
        log.info(f"Delete temp data {doomed_dir}")
        shutil.rmtree(doomed_dir)
def main(
src_paths: list[Path],
output_path: Path,
executor: str,
cpus_per_task: int,
tasks_per_job: int,
workers: int,
resume_from_save: Path = None,
resume_from_aggregate: Path = None,
debug: bool = False,
repo_id: str = None,
push_to_hub: bool = False,
):
tasks = []
pattern1 = re.compile(r"_SCENE\d+_(.*?)_demo\.hdf5")
pattern2 = re.compile(r"(.*?)_demo\.hdf5")
for src_path in src_paths:
for input_h5 in src_path.glob("*.hdf5"):
match = pattern1.search(input_h5.name)
if match is None:
match = pattern2.search(input_h5.name)
if match is None:
continue
tasks.append(
(
input_h5,
(output_path / (src_path.name + "_temp") / input_h5.stem).resolve(),
match.group(1).replace("_", " "),
)
)
if len(src_paths) > 1:
aggregate_output_path = output_path / ("_".join([src_path.name for src_path in src_paths]) + "_aggregated_lerobot")
else:
aggregate_output_path = output_path / f"{src_paths[0].name}_lerobot"
if debug:
executor = "local"
workers = 1
tasks = tasks[:2]
push_to_hub = False
match executor:
case "local":
workers = os.cpu_count() // cpus_per_task if workers == -1 else workers
executor = LocalPipelineExecutor
case "ray":
runtime_env = RuntimeEnv(
env_vars={
"HDF5_USE_FILE_LOCKING": "FALSE",
"HF_DATASETS_DISABLE_PROGRESS_BARS": "TRUE",
"SVT_LOG": "1",
},
)
ray.init(runtime_env=runtime_env)
executor = RayPipelineExecutor
case _:
raise ValueError(f"Executor {executor} not supported")
executor_config = {
"tasks": len(tasks),
"workers": workers,
**({"cpus_per_task": cpus_per_task, "tasks_per_job": tasks_per_job} if executor is RayPipelineExecutor else {}),
}
executor(pipeline=[SaveLerobotDataset(tasks)], **executor_config, logging_dir=resume_from_save).run()
executor(
pipeline=[DeleteTempData([task[1] for task in tasks])],
**executor_config,
depends=executor(
pipeline=[AggregateDatasets([task[1] for task in tasks], aggregate_output_path)],
**executor_config,
logging_dir=resume_from_aggregate,
),
).run()
for task in tasks:
shutil.rmtree(task[1].parent, ignore_errors=True)
if push_to_hub:
assert repo_id is not None
tags = ["LeRobot", "libero", "franka"]
tags.extend([src_path.name for src_path in src_paths])
LeRobotDataset(
repo_id=repo_id,
root=aggregate_output_path,
).push_to_hub(
tags=tags,
private=False,
push_videos=True,
license="apache-2.0",
upload_large_folder=False,
)
if __name__ == "__main__":
    # Command-line entry point: each option's dest matches a keyword
    # argument of main(), so the parsed namespace is forwarded as-is.
    cli = argparse.ArgumentParser()

    # Inputs and outputs
    cli.add_argument("--src-paths", nargs="+", type=Path, required=True)
    cli.add_argument("--output-path", required=True, type=Path)

    # Execution backend and parallelism
    cli.add_argument("--executor", default="local", choices=["local", "ray"], type=str)
    cli.add_argument("--cpus-per-task", default=1, type=int)
    cli.add_argument(
        "--tasks-per-job",
        default=1,
        type=int,
        help="number of concurrent tasks per job, only used for ray",
    )
    cli.add_argument("--workers", default=-1, type=int, help="number of concurrent jobs to run")

    # Resuming and debugging
    cli.add_argument("--resume-from-save", type=Path, help="logs directory to resume from save step")
    cli.add_argument("--resume-from-aggregate", type=Path, help="logs directory to resume from aggregate step")
    cli.add_argument("--debug", action="store_true")

    # Hub upload
    cli.add_argument("--repo-id", type=str, help="required when push-to-hub is True")
    cli.add_argument("--push-to-hub", action="store_true", help="upload to hub")

    main(**vars(cli.parse_args()))