🔨 fix meta && enhance aggr

This commit is contained in:
Tavish
2025-07-11 20:05:12 +08:00
parent 30981e2111
commit cbe3e13375
+31 -17
View File
@@ -4,6 +4,7 @@ import re
import shutil
from pathlib import Path
import numpy as np
import pandas as pd
import ray
from datatrove.executor import LocalPipelineExecutor, RayPipelineExecutor
@@ -115,30 +116,45 @@ class AggregateDatasets(PipelineStep):
datasets_task_index_to_aggr_task_index[dataset_index] = task_index_to_aggr_task_index
datasets_ep_idx_to_aggr_ep_idx = {}
datasets_aggr_episode_index_shift = {}
datasets_aggr_index_shift = {}
aggr_episode_index_shift = 0
for dataset_index, meta in enumerate(tqdm(all_metadata, desc="Aggregate episodes and global index")):
ep_idx_to_aggr_ep_idx = {}
for episode_index in range(meta.total_episodes):
aggr_episode_index = episode_index + aggr_episode_index_shift
ep_idx_to_aggr_ep_idx[episode_index] = aggr_episode_index
datasets_ep_idx_to_aggr_ep_idx[dataset_index] = ep_idx_to_aggr_ep_idx
datasets_aggr_episode_index_shift[dataset_index] = aggr_episode_index_shift
datasets_aggr_episode_index_shift[dataset_index] = aggr_meta.total_episodes
datasets_aggr_index_shift[dataset_index] = aggr_meta.total_frames
# populate episodes
for episode_index, episode_dict in meta.episodes.items():
aggr_episode_index = episode_index + aggr_episode_index_shift
aggr_episode_index = episode_index + aggr_meta.total_episodes
episode_dict["episode_index"] = aggr_episode_index
aggr_meta.episodes[aggr_episode_index] = episode_dict
# populate episodes_stats
for episode_index, episode_stats in meta.episodes_stats.items():
aggr_episode_index = episode_index + aggr_episode_index_shift
aggr_episode_index = episode_index + aggr_meta.total_episodes
episode_stats["index"].update(
{
"min": episode_stats["index"]["min"] + aggr_meta.total_frames,
"max": episode_stats["index"]["max"] + aggr_meta.total_frames,
"mean": episode_stats["index"]["mean"] + aggr_meta.total_frames,
}
)
episode_stats["episode_index"].update(
{
"min": np.array([aggr_episode_index]),
"max": np.array([aggr_episode_index]),
"mean": np.array([aggr_episode_index]),
}
)
df = pd.read_parquet(meta.root / meta.get_data_file_path(episode_index))
df["task_index"] = df["task_index"].map(datasets_task_index_to_aggr_task_index[dataset_index])
episode_stats["task_index"].update(
{
"min": np.array([df["task_index"].min()]),
"max": np.array([df["task_index"].max()]),
"mean": np.array([df["task_index"].mean()]),
"std": np.array([df["task_index"].std()]),
}
)
aggr_meta.episodes_stats[aggr_episode_index] = episode_stats
# populate info
@@ -146,11 +162,9 @@ class AggregateDatasets(PipelineStep):
aggr_meta.info["total_frames"] += meta.total_frames
aggr_meta.info["total_videos"] += len(aggr_meta.video_keys) * meta.total_episodes
aggr_episode_index_shift += meta.total_episodes
logger.info("Write meta data")
aggr_meta.info["total_tasks"] = len(aggr_meta.tasks)
aggr_meta.info["total_chunks"] = aggr_meta.get_episode_chunk(aggr_episode_index_shift - 1)
aggr_meta.info["total_chunks"] = aggr_meta.get_episode_chunk(aggr_meta.total_episodes - 1) + 1
aggr_meta.info["splits"] = {"train": f"0:{aggr_meta.info['total_episodes']}"}
# create a new episodes jsonl with updated episode_index using write_episode
@@ -168,7 +182,6 @@ class AggregateDatasets(PipelineStep):
write_info(aggr_meta.info, aggr_meta.root)
self.datasets_task_index_to_aggr_task_index = datasets_task_index_to_aggr_task_index
self.datasets_ep_idx_to_aggr_ep_idx = datasets_ep_idx_to_aggr_ep_idx
self.datasets_aggr_episode_index_shift = datasets_aggr_episode_index_shift
self.datasets_aggr_index_shift = datasets_aggr_index_shift
@@ -187,7 +200,7 @@ class AggregateDatasets(PipelineStep):
with self.track_time("aggregating dataset"):
logger.info("Copy data")
for episode_index in range(meta.total_episodes):
aggr_episode_index = self.datasets_ep_idx_to_aggr_ep_idx[dataset_index][episode_index]
aggr_episode_index = episode_index + aggr_episode_index_shift
data_path = meta.root / meta.get_data_file_path(episode_index)
aggr_data_path = aggr_meta.root / aggr_meta.get_data_file_path(aggr_episode_index)
aggr_data_path.parent.mkdir(parents=True, exist_ok=True)
@@ -258,6 +271,7 @@ def main(
aggregate_output_path = output_path / ("_".join([src_path.name for src_path in src_paths]) + "_aggregated_lerobot")
else:
aggregate_output_path = output_path / f"{src_paths[0].name}_lerobot"
aggregate_output_path = aggregate_output_path.resolve()
if debug:
executor = "local"