diff --git a/libero2lerobot/libero_h5.py b/libero2lerobot/libero_h5.py index 4ea2df1..f7d10cd 100644 --- a/libero2lerobot/libero_h5.py +++ b/libero2lerobot/libero_h5.py @@ -4,6 +4,7 @@ import re import shutil from pathlib import Path +import numpy as np import pandas as pd import ray from datatrove.executor import LocalPipelineExecutor, RayPipelineExecutor @@ -115,30 +116,45 @@ class AggregateDatasets(PipelineStep): datasets_task_index_to_aggr_task_index[dataset_index] = task_index_to_aggr_task_index - datasets_ep_idx_to_aggr_ep_idx = {} datasets_aggr_episode_index_shift = {} datasets_aggr_index_shift = {} - aggr_episode_index_shift = 0 for dataset_index, meta in enumerate(tqdm(all_metadata, desc="Aggregate episodes and global index")): - ep_idx_to_aggr_ep_idx = {} - - for episode_index in range(meta.total_episodes): - aggr_episode_index = episode_index + aggr_episode_index_shift - ep_idx_to_aggr_ep_idx[episode_index] = aggr_episode_index - - datasets_ep_idx_to_aggr_ep_idx[dataset_index] = ep_idx_to_aggr_ep_idx - datasets_aggr_episode_index_shift[dataset_index] = aggr_episode_index_shift + datasets_aggr_episode_index_shift[dataset_index] = aggr_meta.total_episodes datasets_aggr_index_shift[dataset_index] = aggr_meta.total_frames # populate episodes for episode_index, episode_dict in meta.episodes.items(): - aggr_episode_index = episode_index + aggr_episode_index_shift + aggr_episode_index = episode_index + aggr_meta.total_episodes episode_dict["episode_index"] = aggr_episode_index aggr_meta.episodes[aggr_episode_index] = episode_dict # populate episodes_stats for episode_index, episode_stats in meta.episodes_stats.items(): - aggr_episode_index = episode_index + aggr_episode_index_shift + aggr_episode_index = episode_index + aggr_meta.total_episodes + episode_stats["index"].update( + { + "min": episode_stats["index"]["min"] + aggr_meta.total_frames, + "max": episode_stats["index"]["max"] + aggr_meta.total_frames, + "mean": episode_stats["index"]["mean"] + aggr_meta.total_frames, + } + ) + episode_stats["episode_index"].update( + { + "min": np.array([aggr_episode_index]), + "max": np.array([aggr_episode_index]), + "mean": np.array([aggr_episode_index]), + } + ) + df = pd.read_parquet(meta.root / meta.get_data_file_path(episode_index)) + df["task_index"] = df["task_index"].map(datasets_task_index_to_aggr_task_index[dataset_index]) + episode_stats["task_index"].update( + { + "min": np.array([df["task_index"].min()]), + "max": np.array([df["task_index"].max()]), + "mean": np.array([df["task_index"].mean()]), + "std": np.array([df["task_index"].std()]), + } + ) aggr_meta.episodes_stats[aggr_episode_index] = episode_stats # populate info @@ -146,11 +162,9 @@ class AggregateDatasets(PipelineStep): aggr_meta.info["total_frames"] += meta.total_frames aggr_meta.info["total_videos"] += len(aggr_meta.video_keys) * meta.total_episodes - aggr_episode_index_shift += meta.total_episodes - logger.info("Write meta data") aggr_meta.info["total_tasks"] = len(aggr_meta.tasks) - aggr_meta.info["total_chunks"] = aggr_meta.get_episode_chunk(aggr_episode_index_shift - 1) + aggr_meta.info["total_chunks"] = aggr_meta.get_episode_chunk(aggr_meta.total_episodes - 1) + 1 aggr_meta.info["splits"] = {"train": f"0:{aggr_meta.info['total_episodes']}"} # create a new episodes jsonl with updated episode_index using write_episode @@ -168,7 +182,6 @@ class AggregateDatasets(PipelineStep): write_info(aggr_meta.info, aggr_meta.root) self.datasets_task_index_to_aggr_task_index = datasets_task_index_to_aggr_task_index - self.datasets_ep_idx_to_aggr_ep_idx = datasets_ep_idx_to_aggr_ep_idx self.datasets_aggr_episode_index_shift = datasets_aggr_episode_index_shift self.datasets_aggr_index_shift = datasets_aggr_index_shift @@ -187,7 +200,7 @@ class AggregateDatasets(PipelineStep): with self.track_time("aggregating dataset"): logger.info("Copy data") for episode_index in range(meta.total_episodes): - aggr_episode_index = self.datasets_ep_idx_to_aggr_ep_idx[dataset_index][episode_index] + aggr_episode_index = episode_index + aggr_episode_index_shift data_path = meta.root / meta.get_data_file_path(episode_index) aggr_data_path = aggr_meta.root / aggr_meta.get_data_file_path(aggr_episode_index) aggr_data_path.parent.mkdir(parents=True, exist_ok=True) @@ -258,6 +271,7 @@ def main( aggregate_output_path = output_path / ("_".join([src_path.name for src_path in src_paths]) + "_aggregated_lerobot") else: aggregate_output_path = output_path / f"{src_paths[0].name}_lerobot" + aggregate_output_path = aggregate_output_path.resolve() if debug: executor = "local"