mirror of
https://github.com/Tavish9/any4lerobot.git
synced 2026-05-16 14:39:41 +00:00
🔨 fix meta && enhance aggr
This commit is contained in:
+31
-17
@@ -4,6 +4,7 @@ import re
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import ray
|
||||
from datatrove.executor import LocalPipelineExecutor, RayPipelineExecutor
|
||||
@@ -115,30 +116,45 @@ class AggregateDatasets(PipelineStep):
|
||||
|
||||
datasets_task_index_to_aggr_task_index[dataset_index] = task_index_to_aggr_task_index
|
||||
|
||||
datasets_ep_idx_to_aggr_ep_idx = {}
|
||||
datasets_aggr_episode_index_shift = {}
|
||||
datasets_aggr_index_shift = {}
|
||||
aggr_episode_index_shift = 0
|
||||
for dataset_index, meta in enumerate(tqdm(all_metadata, desc="Aggregate episodes and global index")):
|
||||
ep_idx_to_aggr_ep_idx = {}
|
||||
|
||||
for episode_index in range(meta.total_episodes):
|
||||
aggr_episode_index = episode_index + aggr_episode_index_shift
|
||||
ep_idx_to_aggr_ep_idx[episode_index] = aggr_episode_index
|
||||
|
||||
datasets_ep_idx_to_aggr_ep_idx[dataset_index] = ep_idx_to_aggr_ep_idx
|
||||
datasets_aggr_episode_index_shift[dataset_index] = aggr_episode_index_shift
|
||||
datasets_aggr_episode_index_shift[dataset_index] = aggr_meta.total_episodes
|
||||
datasets_aggr_index_shift[dataset_index] = aggr_meta.total_frames
|
||||
|
||||
# populate episodes
|
||||
for episode_index, episode_dict in meta.episodes.items():
|
||||
aggr_episode_index = episode_index + aggr_episode_index_shift
|
||||
aggr_episode_index = episode_index + aggr_meta.total_episodes
|
||||
episode_dict["episode_index"] = aggr_episode_index
|
||||
aggr_meta.episodes[aggr_episode_index] = episode_dict
|
||||
|
||||
# populate episodes_stats
|
||||
for episode_index, episode_stats in meta.episodes_stats.items():
|
||||
aggr_episode_index = episode_index + aggr_episode_index_shift
|
||||
aggr_episode_index = episode_index + aggr_meta.total_episodes
|
||||
episode_stats["index"].update(
|
||||
{
|
||||
"min": episode_stats["index"]["min"] + aggr_meta.total_frames,
|
||||
"max": episode_stats["index"]["max"] + aggr_meta.total_frames,
|
||||
"mean": episode_stats["index"]["mean"] + aggr_meta.total_frames,
|
||||
}
|
||||
)
|
||||
episode_stats["episode_index"].update(
|
||||
{
|
||||
"min": np.array([aggr_episode_index]),
|
||||
"max": np.array([aggr_episode_index]),
|
||||
"mean": np.array([aggr_episode_index]),
|
||||
}
|
||||
)
|
||||
df = pd.read_parquet(meta.root / meta.get_data_file_path(episode_index))
|
||||
df["task_index"] = df["task_index"].map(datasets_task_index_to_aggr_task_index[dataset_index])
|
||||
episode_stats["task_index"].update(
|
||||
{
|
||||
"min": np.array([df["task_index"].min()]),
|
||||
"max": np.array([df["task_index"].max()]),
|
||||
"mean": np.array([df["task_index"].mean()]),
|
||||
"std": np.array([df["task_index"].std()]),
|
||||
}
|
||||
)
|
||||
aggr_meta.episodes_stats[aggr_episode_index] = episode_stats
|
||||
|
||||
# populate info
|
||||
@@ -146,11 +162,9 @@ class AggregateDatasets(PipelineStep):
|
||||
aggr_meta.info["total_frames"] += meta.total_frames
|
||||
aggr_meta.info["total_videos"] += len(aggr_meta.video_keys) * meta.total_episodes
|
||||
|
||||
aggr_episode_index_shift += meta.total_episodes
|
||||
|
||||
logger.info("Write meta data")
|
||||
aggr_meta.info["total_tasks"] = len(aggr_meta.tasks)
|
||||
aggr_meta.info["total_chunks"] = aggr_meta.get_episode_chunk(aggr_episode_index_shift - 1)
|
||||
aggr_meta.info["total_chunks"] = aggr_meta.get_episode_chunk(aggr_meta.total_episodes - 1) + 1
|
||||
aggr_meta.info["splits"] = {"train": f"0:{aggr_meta.info['total_episodes']}"}
|
||||
|
||||
# create a new episodes jsonl with updated episode_index using write_episode
|
||||
@@ -168,7 +182,6 @@ class AggregateDatasets(PipelineStep):
|
||||
write_info(aggr_meta.info, aggr_meta.root)
|
||||
|
||||
self.datasets_task_index_to_aggr_task_index = datasets_task_index_to_aggr_task_index
|
||||
self.datasets_ep_idx_to_aggr_ep_idx = datasets_ep_idx_to_aggr_ep_idx
|
||||
self.datasets_aggr_episode_index_shift = datasets_aggr_episode_index_shift
|
||||
self.datasets_aggr_index_shift = datasets_aggr_index_shift
|
||||
|
||||
@@ -187,7 +200,7 @@ class AggregateDatasets(PipelineStep):
|
||||
with self.track_time("aggregating dataset"):
|
||||
logger.info("Copy data")
|
||||
for episode_index in range(meta.total_episodes):
|
||||
aggr_episode_index = self.datasets_ep_idx_to_aggr_ep_idx[dataset_index][episode_index]
|
||||
aggr_episode_index = episode_index + aggr_episode_index_shift
|
||||
data_path = meta.root / meta.get_data_file_path(episode_index)
|
||||
aggr_data_path = aggr_meta.root / aggr_meta.get_data_file_path(aggr_episode_index)
|
||||
aggr_data_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
@@ -258,6 +271,7 @@ def main(
|
||||
aggregate_output_path = output_path / ("_".join([src_path.name for src_path in src_paths]) + "_aggregated_lerobot")
|
||||
else:
|
||||
aggregate_output_path = output_path / f"{src_paths[0].name}_lerobot"
|
||||
aggregate_output_path = aggregate_output_path.resolve()
|
||||
|
||||
if debug:
|
||||
executor = "local"
|
||||
|
||||
Reference in New Issue
Block a user