🔨 fix meta && enhance aggr

This commit is contained in:
Tavish
2025-07-11 20:05:12 +08:00
parent 30981e2111
commit cbe3e13375
+31 -17
View File
@@ -4,6 +4,7 @@ import re
import shutil import shutil
from pathlib import Path from pathlib import Path
import numpy as np
import pandas as pd import pandas as pd
import ray import ray
from datatrove.executor import LocalPipelineExecutor, RayPipelineExecutor from datatrove.executor import LocalPipelineExecutor, RayPipelineExecutor
@@ -115,30 +116,45 @@ class AggregateDatasets(PipelineStep):
datasets_task_index_to_aggr_task_index[dataset_index] = task_index_to_aggr_task_index datasets_task_index_to_aggr_task_index[dataset_index] = task_index_to_aggr_task_index
datasets_ep_idx_to_aggr_ep_idx = {}
datasets_aggr_episode_index_shift = {} datasets_aggr_episode_index_shift = {}
datasets_aggr_index_shift = {} datasets_aggr_index_shift = {}
aggr_episode_index_shift = 0
for dataset_index, meta in enumerate(tqdm(all_metadata, desc="Aggregate episodes and global index")): for dataset_index, meta in enumerate(tqdm(all_metadata, desc="Aggregate episodes and global index")):
ep_idx_to_aggr_ep_idx = {} datasets_aggr_episode_index_shift[dataset_index] = aggr_meta.total_episodes
for episode_index in range(meta.total_episodes):
aggr_episode_index = episode_index + aggr_episode_index_shift
ep_idx_to_aggr_ep_idx[episode_index] = aggr_episode_index
datasets_ep_idx_to_aggr_ep_idx[dataset_index] = ep_idx_to_aggr_ep_idx
datasets_aggr_episode_index_shift[dataset_index] = aggr_episode_index_shift
datasets_aggr_index_shift[dataset_index] = aggr_meta.total_frames datasets_aggr_index_shift[dataset_index] = aggr_meta.total_frames
# populate episodes # populate episodes
for episode_index, episode_dict in meta.episodes.items(): for episode_index, episode_dict in meta.episodes.items():
aggr_episode_index = episode_index + aggr_episode_index_shift aggr_episode_index = episode_index + aggr_meta.total_episodes
episode_dict["episode_index"] = aggr_episode_index episode_dict["episode_index"] = aggr_episode_index
aggr_meta.episodes[aggr_episode_index] = episode_dict aggr_meta.episodes[aggr_episode_index] = episode_dict
# populate episodes_stats # populate episodes_stats
for episode_index, episode_stats in meta.episodes_stats.items(): for episode_index, episode_stats in meta.episodes_stats.items():
aggr_episode_index = episode_index + aggr_episode_index_shift aggr_episode_index = episode_index + aggr_meta.total_episodes
episode_stats["index"].update(
{
"min": episode_stats["index"]["min"] + aggr_meta.total_frames,
"max": episode_stats["index"]["max"] + aggr_meta.total_frames,
"mean": episode_stats["index"]["mean"] + aggr_meta.total_frames,
}
)
episode_stats["episode_index"].update(
{
"min": np.array([aggr_episode_index]),
"max": np.array([aggr_episode_index]),
"mean": np.array([aggr_episode_index]),
}
)
df = pd.read_parquet(meta.root / meta.get_data_file_path(episode_index))
df["task_index"] = df["task_index"].map(datasets_task_index_to_aggr_task_index[dataset_index])
episode_stats["task_index"].update(
{
"min": np.array([df["task_index"].min()]),
"max": np.array([df["task_index"].max()]),
"mean": np.array([df["task_index"].mean()]),
"std": np.array([df["task_index"].std()]),
}
)
aggr_meta.episodes_stats[aggr_episode_index] = episode_stats aggr_meta.episodes_stats[aggr_episode_index] = episode_stats
# populate info # populate info
@@ -146,11 +162,9 @@ class AggregateDatasets(PipelineStep):
aggr_meta.info["total_frames"] += meta.total_frames aggr_meta.info["total_frames"] += meta.total_frames
aggr_meta.info["total_videos"] += len(aggr_meta.video_keys) * meta.total_episodes aggr_meta.info["total_videos"] += len(aggr_meta.video_keys) * meta.total_episodes
aggr_episode_index_shift += meta.total_episodes
logger.info("Write meta data") logger.info("Write meta data")
aggr_meta.info["total_tasks"] = len(aggr_meta.tasks) aggr_meta.info["total_tasks"] = len(aggr_meta.tasks)
aggr_meta.info["total_chunks"] = aggr_meta.get_episode_chunk(aggr_episode_index_shift - 1) aggr_meta.info["total_chunks"] = aggr_meta.get_episode_chunk(aggr_meta.total_episodes - 1) + 1
aggr_meta.info["splits"] = {"train": f"0:{aggr_meta.info['total_episodes']}"} aggr_meta.info["splits"] = {"train": f"0:{aggr_meta.info['total_episodes']}"}
# create a new episodes jsonl with updated episode_index using write_episode # create a new episodes jsonl with updated episode_index using write_episode
@@ -168,7 +182,6 @@ class AggregateDatasets(PipelineStep):
write_info(aggr_meta.info, aggr_meta.root) write_info(aggr_meta.info, aggr_meta.root)
self.datasets_task_index_to_aggr_task_index = datasets_task_index_to_aggr_task_index self.datasets_task_index_to_aggr_task_index = datasets_task_index_to_aggr_task_index
self.datasets_ep_idx_to_aggr_ep_idx = datasets_ep_idx_to_aggr_ep_idx
self.datasets_aggr_episode_index_shift = datasets_aggr_episode_index_shift self.datasets_aggr_episode_index_shift = datasets_aggr_episode_index_shift
self.datasets_aggr_index_shift = datasets_aggr_index_shift self.datasets_aggr_index_shift = datasets_aggr_index_shift
@@ -187,7 +200,7 @@ class AggregateDatasets(PipelineStep):
with self.track_time("aggregating dataset"): with self.track_time("aggregating dataset"):
logger.info("Copy data") logger.info("Copy data")
for episode_index in range(meta.total_episodes): for episode_index in range(meta.total_episodes):
aggr_episode_index = self.datasets_ep_idx_to_aggr_ep_idx[dataset_index][episode_index] aggr_episode_index = episode_index + aggr_episode_index_shift
data_path = meta.root / meta.get_data_file_path(episode_index) data_path = meta.root / meta.get_data_file_path(episode_index)
aggr_data_path = aggr_meta.root / aggr_meta.get_data_file_path(aggr_episode_index) aggr_data_path = aggr_meta.root / aggr_meta.get_data_file_path(aggr_episode_index)
aggr_data_path.parent.mkdir(parents=True, exist_ok=True) aggr_data_path.parent.mkdir(parents=True, exist_ok=True)
@@ -258,6 +271,7 @@ def main(
aggregate_output_path = output_path / ("_".join([src_path.name for src_path in src_paths]) + "_aggregated_lerobot") aggregate_output_path = output_path / ("_".join([src_path.name for src_path in src_paths]) + "_aggregated_lerobot")
else: else:
aggregate_output_path = output_path / f"{src_paths[0].name}_lerobot" aggregate_output_path = output_path / f"{src_paths[0].name}_lerobot"
aggregate_output_path = aggregate_output_path.resolve()
if debug: if debug:
executor = "local" executor = "local"