# Mirror of https://github.com/Tavish9/any4lerobot.git
# Synced 2026-05-23 17:59:41 +00:00 (337 lines, 13 KiB, Python)
import argparse
|
|
import os
|
|
import re
|
|
import shutil
|
|
from pathlib import Path
|
|
|
|
import pandas as pd
|
|
import ray
|
|
from datatrove.executor import LocalPipelineExecutor, RayPipelineExecutor
|
|
from datatrove.pipeline.base import PipelineStep
|
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
|
from lerobot.common.datasets.utils import (
|
|
write_episode,
|
|
write_episode_stats,
|
|
write_info,
|
|
write_task,
|
|
)
|
|
from libero_utils.config import LIBERO_FEATURES
|
|
from libero_utils.lerobot_utils import validate_all_metadata
|
|
from libero_utils.libero_utils import load_local_episodes
|
|
from ray.runtime_env import RuntimeEnv
|
|
from tqdm import tqdm
|
|
|
|
|
|
def setup_logger():
    """Reset datatrove's logger to a single colorized stdout sink and return it.

    ``remove()`` is called before ``add()`` so that the sink configured here
    replaces whatever was registered previously.

    Returns:
        The ``datatrove.utils.logging.logger`` object, configured at INFO level.
    """
    import sys

    from datatrove.utils.logging import logger as datatrove_logger

    # clear previously registered sinks, then install our stdout sink
    datatrove_logger.remove()
    datatrove_logger.add(sys.stdout, level="INFO", colorize=True)
    return datatrove_logger
|
|
|
|
|
|
class SaveLerobotDataset(PipelineStep):
    """Convert one LIBERO HDF5 file per rank into a temporary LeRobot dataset.

    Each entry of ``tasks`` is ``(input_h5, output_path, task_instruction)``;
    rank ``i`` processes ``tasks[i]``.
    """

    name = "Save Temp LerobotDataset"
    type = "libero2lerobot"

    def __init__(self, tasks: list[tuple[Path, Path, str]]):
        super().__init__()
        # one (source hdf5, destination dir, language instruction) triple per rank
        self.tasks = tasks

    def run(self, data=None, rank: int = 0, world_size: int = 1):
        """Write every episode of ``self.tasks[rank]`` into a freshly created LeRobot dataset.

        Args:
            data: unused; present to satisfy the ``PipelineStep`` interface.
            rank: index into ``self.tasks`` selecting the file to convert.
            world_size: unused.
        """
        logger = setup_logger()
        src_h5, dst_root, instruction = self.tasks[rank]

        # remove leftovers from a previous attempt before creating the dataset
        if dst_root.exists():
            shutil.rmtree(dst_root)

        lerobot_ds = LeRobotDataset.create(
            repo_id=f"{src_h5.parent.name}/{src_h5.name}",
            root=dst_root,
            fps=20,
            robot_type="franka",
            features=LIBERO_FEATURES,
        )
        logger.info(f"start processing for {src_h5}, saving to {dst_root}")

        # every frame of the file is labelled with the same instruction,
        # derived from the file name by the caller
        for ep_idx, frames in enumerate(load_local_episodes(src_h5)):
            with self.track_time("saving episode"):
                for frame in frames:
                    lerobot_ds.add_frame(frame, task=instruction)
                lerobot_ds.save_episode()
                logger.info(f"process done for {lerobot_ds.repo_id}, episode {ep_idx}, len {len(frames)}")
|
|
|
|
|
|
class AggregateDatasets(PipelineStep):
    """Merge several LeRobot datasets (one per entry of ``raw_dirs``) into ``aggregated_dir``.

    Merged metadata (tasks, episodes, per-episode stats, info) is written once, in
    ``__init__``. Each rank's ``run`` then copies the parquet data and videos of
    ``raw_dirs[rank]`` into the aggregated dataset, rewriting ``index``,
    ``episode_index`` and ``task_index`` so they are globally consistent.
    """

    name = "Aggregate Datasets"
    type = "libero2lerobot"

    def __init__(self, raw_dirs: list[Path], aggregated_dir: Path):
        super().__init__()
        self.raw_dirs = raw_dirs
        self.aggregated_dir = aggregated_dir

        # Metadata and index remappings are built at construction time,
        # before any rank's ``run`` is invoked.
        self.create_aggr_dataset()

    def create_aggr_dataset(self):
        """Write merged metadata to ``aggregated_dir`` and record index remappings.

        Any existing ``aggregated_dir`` is deleted first. Sets the following
        attributes, each keyed by the dataset's position in ``raw_dirs``:

        - ``datasets_task_index_to_aggr_task_index``: local -> aggregated task index
        - ``datasets_ep_idx_to_aggr_ep_idx``: local -> aggregated episode index
        - ``datasets_aggr_episode_index_shift``: offset added to episode indices
        - ``datasets_aggr_index_shift``: offset added to global frame indices
        """
        logger = setup_logger()

        all_metadata = [LeRobotDatasetMetadata("", root=raw_dir) for raw_dir in self.raw_dirs]
        fps, robot_type, features = validate_all_metadata(all_metadata)

        if self.aggregated_dir.exists():
            shutil.rmtree(self.aggregated_dir)

        aggr_meta = LeRobotDatasetMetadata.create(
            repo_id=f"{self.aggregated_dir.parent.name}/{self.aggregated_dir.name}",
            root=self.aggregated_dir,
            fps=fps,
            robot_type=robot_type,
            features=features,
        )

        # Tasks are deduplicated by their text; each new task gets the next aggregated index.
        datasets_task_index_to_aggr_task_index = {}
        aggr_task_index = 0
        for dataset_index, meta in enumerate(tqdm(all_metadata, desc="Aggregate tasks index")):
            task_index_to_aggr_task_index = {}
            for task_index, task in meta.tasks.items():
                if task not in aggr_meta.task_to_task_index:
                    # add the task to aggr tasks mappings
                    aggr_meta.tasks[aggr_task_index] = task
                    aggr_meta.task_to_task_index[task] = aggr_task_index
                    aggr_task_index += 1
                task_index_to_aggr_task_index[task_index] = aggr_meta.task_to_task_index[task]
            datasets_task_index_to_aggr_task_index[dataset_index] = task_index_to_aggr_task_index

        # Episodes (and frames) of dataset i are placed after those of datasets 0..i-1.
        datasets_ep_idx_to_aggr_ep_idx = {}
        datasets_aggr_episode_index_shift = {}
        datasets_aggr_index_shift = {}
        aggr_episode_index_shift = 0
        for dataset_index, meta in enumerate(tqdm(all_metadata, desc="Aggregate episodes and global index")):
            datasets_ep_idx_to_aggr_ep_idx[dataset_index] = {
                episode_index: episode_index + aggr_episode_index_shift
                for episode_index in range(meta.total_episodes)
            }
            datasets_aggr_episode_index_shift[dataset_index] = aggr_episode_index_shift
            # frames accumulated so far == global index of this dataset's first frame
            datasets_aggr_index_shift[dataset_index] = aggr_meta.total_frames

            # populate episodes
            for episode_index, episode_dict in meta.episodes.items():
                aggr_episode_index = episode_index + aggr_episode_index_shift
                episode_dict["episode_index"] = aggr_episode_index
                aggr_meta.episodes[aggr_episode_index] = episode_dict

            # populate episodes_stats
            # NOTE(review): stats are copied as-is; if they contain entries for
            # "index" / "episode_index" / "task_index", those are not re-based to
            # the aggregated ids the way the parquet columns are in ``run``.
            for episode_index, episode_stats in meta.episodes_stats.items():
                aggr_episode_index = episode_index + aggr_episode_index_shift
                aggr_meta.episodes_stats[aggr_episode_index] = episode_stats

            # populate info
            aggr_meta.info["total_episodes"] += meta.total_episodes
            aggr_meta.info["total_frames"] += meta.total_frames
            aggr_meta.info["total_videos"] += len(aggr_meta.video_keys) * meta.total_episodes

            aggr_episode_index_shift += meta.total_episodes

        logger.info("Write meta data")
        total_episodes = aggr_episode_index_shift
        aggr_meta.info["total_tasks"] = len(aggr_meta.tasks)
        # total_chunks is a count: the chunk id of the last episode plus one, or 0 when
        # there are no episodes. The chunk id alone is off by one (e.g. 0 for a
        # dataset that fits in a single chunk) and is -1 for an empty dataset.
        aggr_meta.info["total_chunks"] = (
            aggr_meta.get_episode_chunk(total_episodes - 1) + 1 if total_episodes > 0 else 0
        )
        aggr_meta.info["splits"] = {"train": f"0:{aggr_meta.info['total_episodes']}"}

        # write the episodes jsonl with the re-based episode_index values
        for episode_dict in tqdm(aggr_meta.episodes.values(), desc="Write episodes info"):
            write_episode(episode_dict, aggr_meta.root)

        # write the per-episode stats keyed by the re-based episode index
        for episode_index, episode_stats in tqdm(aggr_meta.episodes_stats.items(), desc="Write episodes stats info"):
            write_episode_stats(episode_index, episode_stats, aggr_meta.root)

        # write the deduplicated task list with aggregated task indices
        for task_index, task in tqdm(aggr_meta.tasks.items(), desc="Write tasks info"):
            write_task(task_index, task, aggr_meta.root)

        write_info(aggr_meta.info, aggr_meta.root)

        self.datasets_task_index_to_aggr_task_index = datasets_task_index_to_aggr_task_index
        self.datasets_ep_idx_to_aggr_ep_idx = datasets_ep_idx_to_aggr_ep_idx
        self.datasets_aggr_episode_index_shift = datasets_aggr_episode_index_shift
        self.datasets_aggr_index_shift = datasets_aggr_index_shift

        logger.info("Meta data done writing")

    def run(self, data=None, rank: int = 0, world_size: int = 1):
        """Copy the data and videos of ``raw_dirs[rank]`` into the aggregated dataset.

        Parquet files are rewritten with shifted ``index`` / ``episode_index`` and
        remapped ``task_index``; video files are copied byte-for-byte to their new
        episode paths.
        """
        logger = setup_logger()

        dataset_index = rank
        aggr_meta = LeRobotDatasetMetadata("", root=self.aggregated_dir)
        meta = LeRobotDatasetMetadata("", root=self.raw_dirs[dataset_index])
        aggr_episode_index_shift = self.datasets_aggr_episode_index_shift[dataset_index]
        aggr_index_shift = self.datasets_aggr_index_shift[dataset_index]
        task_index_to_aggr_task_index = self.datasets_task_index_to_aggr_task_index[dataset_index]
        # single source of truth for local -> aggregated episode ids (used for data AND videos)
        ep_idx_to_aggr_ep_idx = self.datasets_ep_idx_to_aggr_ep_idx[dataset_index]

        with self.track_time("aggregating dataset"):
            logger.info("Copy data")
            for episode_index in range(meta.total_episodes):
                aggr_episode_index = ep_idx_to_aggr_ep_idx[episode_index]
                data_path = meta.root / meta.get_data_file_path(episode_index)
                aggr_data_path = aggr_meta.root / aggr_meta.get_data_file_path(aggr_episode_index)
                aggr_data_path.parent.mkdir(parents=True, exist_ok=True)

                # update index, episode_index and task_index
                df = pd.read_parquet(data_path)
                df["index"] += aggr_index_shift
                df["episode_index"] += aggr_episode_index_shift
                df["task_index"] = df["task_index"].map(task_index_to_aggr_task_index)
                df.to_parquet(aggr_data_path)

            logger.info("Copy videos")
            for episode_index in range(meta.total_episodes):
                aggr_episode_index = ep_idx_to_aggr_ep_idx[episode_index]
                for vid_key in meta.video_keys:
                    video_path = meta.root / meta.get_video_file_path(episode_index, vid_key)
                    aggr_video_path = aggr_meta.root / aggr_meta.get_video_file_path(aggr_episode_index, vid_key)
                    aggr_video_path.parent.mkdir(parents=True, exist_ok=True)
                    shutil.copy(video_path, aggr_video_path)
|
|
|
|
|
|
class DeleteTempData(PipelineStep):
    """Recursively delete one temporary dataset directory per rank."""

    name = "Delete Temp Data"
    type = "libero2lerobot"

    def __init__(self, temp_dirs: list[Path]):
        super().__init__()
        # rank i is responsible for temp_dirs[i]
        self.temp_dirs = temp_dirs

    def run(self, data=None, rank: int = 0, world_size: int = 1):
        """Delete ``self.temp_dirs[rank]`` and everything under it."""
        logger = setup_logger()
        target_dir = self.temp_dirs[rank]
        logger.info(f"Delete temp data {target_dir}")
        shutil.rmtree(target_dir)
|
|
|
|
|
|
def main(
|
|
src_paths: list[Path],
|
|
output_path: Path,
|
|
executor: str,
|
|
cpus_per_task: int,
|
|
tasks_per_job: int,
|
|
workers: int,
|
|
resume_from_save: Path = None,
|
|
resume_from_aggregate: Path = None,
|
|
debug: bool = False,
|
|
repo_id: str = None,
|
|
push_to_hub: bool = False,
|
|
):
|
|
tasks = []
|
|
pattern1 = re.compile(r"_SCENE\d+_(.*?)_demo\.hdf5")
|
|
pattern2 = re.compile(r"(.*?)_demo\.hdf5")
|
|
for src_path in src_paths:
|
|
for input_h5 in src_path.glob("*.hdf5"):
|
|
match = pattern1.search(input_h5.name)
|
|
if match is None:
|
|
match = pattern2.search(input_h5.name)
|
|
if match is None:
|
|
continue
|
|
tasks.append(
|
|
(
|
|
input_h5,
|
|
(output_path / (src_path.name + "_temp") / input_h5.stem).resolve(),
|
|
match.group(1).replace("_", " "),
|
|
)
|
|
)
|
|
if len(src_paths) > 1:
|
|
aggregate_output_path = output_path / ("_".join([src_path.name for src_path in src_paths]) + "_aggregated_lerobot")
|
|
else:
|
|
aggregate_output_path = output_path / f"{src_paths[0].name}_lerobot"
|
|
|
|
if debug:
|
|
executor = "local"
|
|
workers = 1
|
|
tasks = tasks[:2]
|
|
push_to_hub = False
|
|
|
|
match executor:
|
|
case "local":
|
|
workers = os.cpu_count() // cpus_per_task if workers == -1 else workers
|
|
executor = LocalPipelineExecutor
|
|
case "ray":
|
|
runtime_env = RuntimeEnv(
|
|
env_vars={
|
|
"HDF5_USE_FILE_LOCKING": "FALSE",
|
|
"HF_DATASETS_DISABLE_PROGRESS_BARS": "TRUE",
|
|
"SVT_LOG": "1",
|
|
},
|
|
)
|
|
ray.init(runtime_env=runtime_env)
|
|
executor = RayPipelineExecutor
|
|
case _:
|
|
raise ValueError(f"Executor {executor} not supported")
|
|
|
|
executor_config = {
|
|
"tasks": len(tasks),
|
|
"workers": workers,
|
|
**({"cpus_per_task": cpus_per_task, "tasks_per_job": tasks_per_job} if executor is RayPipelineExecutor else {}),
|
|
}
|
|
|
|
executor(pipeline=[SaveLerobotDataset(tasks)], **executor_config, logging_dir=resume_from_save).run()
|
|
executor(
|
|
pipeline=[DeleteTempData([task[1] for task in tasks])],
|
|
**executor_config,
|
|
depends=executor(
|
|
pipeline=[AggregateDatasets([task[1] for task in tasks], aggregate_output_path)],
|
|
**executor_config,
|
|
logging_dir=resume_from_aggregate,
|
|
),
|
|
).run()
|
|
|
|
for task in tasks:
|
|
shutil.rmtree(task[1].parent, ignore_errors=True)
|
|
|
|
if push_to_hub:
|
|
assert repo_id is not None
|
|
tags = ["LeRobot", "libero", "franka"]
|
|
tags.extend([src_path.name for src_path in src_paths])
|
|
LeRobotDataset(
|
|
repo_id=repo_id,
|
|
root=aggregate_output_path,
|
|
).push_to_hub(
|
|
tags=tags,
|
|
private=False,
|
|
push_videos=True,
|
|
license="apache-2.0",
|
|
upload_large_folder=False,
|
|
)
|
|
|
|
|
|
if __name__ == "__main__":
    # command-line entry point: flags map 1:1 onto main()'s keyword arguments
    cli = argparse.ArgumentParser()
    cli.add_argument("--src-paths", type=Path, nargs="+", required=True)
    cli.add_argument("--output-path", type=Path, required=True)
    cli.add_argument("--executor", type=str, choices=["local", "ray"], default="local")
    cli.add_argument("--cpus-per-task", type=int, default=1)
    cli.add_argument(
        "--tasks-per-job", type=int, default=1, help="number of concurrent tasks per job, only used for ray"
    )
    cli.add_argument("--workers", type=int, default=-1, help="number of concurrent jobs to run")
    cli.add_argument("--resume-from-save", type=Path, help="logs directory to resume from save step")
    cli.add_argument("--resume-from-aggregate", type=Path, help="logs directory to resume from aggregate step")
    cli.add_argument("--debug", action="store_true")
    cli.add_argument("--repo-id", type=str, help="required when push-to-hub is True")
    cli.add_argument("--push-to-hub", action="store_true", help="upload to hub")

    parsed_args = cli.parse_args()
    main(**vars(parsed_args))
|