mirror of
https://github.com/Tavish9/any4lerobot.git
synced 2026-05-28 20:09:40 +00:00
🔨 fix meta && enhance aggr
This commit is contained in:
+31
-17
@@ -4,6 +4,7 @@ import re
|
|||||||
import shutil
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import ray
|
import ray
|
||||||
from datatrove.executor import LocalPipelineExecutor, RayPipelineExecutor
|
from datatrove.executor import LocalPipelineExecutor, RayPipelineExecutor
|
||||||
@@ -115,30 +116,45 @@ class AggregateDatasets(PipelineStep):
|
|||||||
|
|
||||||
datasets_task_index_to_aggr_task_index[dataset_index] = task_index_to_aggr_task_index
|
datasets_task_index_to_aggr_task_index[dataset_index] = task_index_to_aggr_task_index
|
||||||
|
|
||||||
datasets_ep_idx_to_aggr_ep_idx = {}
|
|
||||||
datasets_aggr_episode_index_shift = {}
|
datasets_aggr_episode_index_shift = {}
|
||||||
datasets_aggr_index_shift = {}
|
datasets_aggr_index_shift = {}
|
||||||
aggr_episode_index_shift = 0
|
|
||||||
for dataset_index, meta in enumerate(tqdm(all_metadata, desc="Aggregate episodes and global index")):
|
for dataset_index, meta in enumerate(tqdm(all_metadata, desc="Aggregate episodes and global index")):
|
||||||
ep_idx_to_aggr_ep_idx = {}
|
datasets_aggr_episode_index_shift[dataset_index] = aggr_meta.total_episodes
|
||||||
|
|
||||||
for episode_index in range(meta.total_episodes):
|
|
||||||
aggr_episode_index = episode_index + aggr_episode_index_shift
|
|
||||||
ep_idx_to_aggr_ep_idx[episode_index] = aggr_episode_index
|
|
||||||
|
|
||||||
datasets_ep_idx_to_aggr_ep_idx[dataset_index] = ep_idx_to_aggr_ep_idx
|
|
||||||
datasets_aggr_episode_index_shift[dataset_index] = aggr_episode_index_shift
|
|
||||||
datasets_aggr_index_shift[dataset_index] = aggr_meta.total_frames
|
datasets_aggr_index_shift[dataset_index] = aggr_meta.total_frames
|
||||||
|
|
||||||
# populate episodes
|
# populate episodes
|
||||||
for episode_index, episode_dict in meta.episodes.items():
|
for episode_index, episode_dict in meta.episodes.items():
|
||||||
aggr_episode_index = episode_index + aggr_episode_index_shift
|
aggr_episode_index = episode_index + aggr_meta.total_episodes
|
||||||
episode_dict["episode_index"] = aggr_episode_index
|
episode_dict["episode_index"] = aggr_episode_index
|
||||||
aggr_meta.episodes[aggr_episode_index] = episode_dict
|
aggr_meta.episodes[aggr_episode_index] = episode_dict
|
||||||
|
|
||||||
# populate episodes_stats
|
# populate episodes_stats
|
||||||
for episode_index, episode_stats in meta.episodes_stats.items():
|
for episode_index, episode_stats in meta.episodes_stats.items():
|
||||||
aggr_episode_index = episode_index + aggr_episode_index_shift
|
aggr_episode_index = episode_index + aggr_meta.total_episodes
|
||||||
|
episode_stats["index"].update(
|
||||||
|
{
|
||||||
|
"min": episode_stats["index"]["min"] + aggr_meta.total_frames,
|
||||||
|
"max": episode_stats["index"]["max"] + aggr_meta.total_frames,
|
||||||
|
"mean": episode_stats["index"]["mean"] + aggr_meta.total_frames,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
episode_stats["episode_index"].update(
|
||||||
|
{
|
||||||
|
"min": np.array([aggr_episode_index]),
|
||||||
|
"max": np.array([aggr_episode_index]),
|
||||||
|
"mean": np.array([aggr_episode_index]),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
df = pd.read_parquet(meta.root / meta.get_data_file_path(episode_index))
|
||||||
|
df["task_index"] = df["task_index"].map(datasets_task_index_to_aggr_task_index[dataset_index])
|
||||||
|
episode_stats["task_index"].update(
|
||||||
|
{
|
||||||
|
"min": np.array([df["task_index"].min()]),
|
||||||
|
"max": np.array([df["task_index"].max()]),
|
||||||
|
"mean": np.array([df["task_index"].mean()]),
|
||||||
|
"std": np.array([df["task_index"].std()]),
|
||||||
|
}
|
||||||
|
)
|
||||||
aggr_meta.episodes_stats[aggr_episode_index] = episode_stats
|
aggr_meta.episodes_stats[aggr_episode_index] = episode_stats
|
||||||
|
|
||||||
# populate info
|
# populate info
|
||||||
@@ -146,11 +162,9 @@ class AggregateDatasets(PipelineStep):
|
|||||||
aggr_meta.info["total_frames"] += meta.total_frames
|
aggr_meta.info["total_frames"] += meta.total_frames
|
||||||
aggr_meta.info["total_videos"] += len(aggr_meta.video_keys) * meta.total_episodes
|
aggr_meta.info["total_videos"] += len(aggr_meta.video_keys) * meta.total_episodes
|
||||||
|
|
||||||
aggr_episode_index_shift += meta.total_episodes
|
|
||||||
|
|
||||||
logger.info("Write meta data")
|
logger.info("Write meta data")
|
||||||
aggr_meta.info["total_tasks"] = len(aggr_meta.tasks)
|
aggr_meta.info["total_tasks"] = len(aggr_meta.tasks)
|
||||||
aggr_meta.info["total_chunks"] = aggr_meta.get_episode_chunk(aggr_episode_index_shift - 1)
|
aggr_meta.info["total_chunks"] = aggr_meta.get_episode_chunk(aggr_meta.total_episodes - 1) + 1
|
||||||
aggr_meta.info["splits"] = {"train": f"0:{aggr_meta.info['total_episodes']}"}
|
aggr_meta.info["splits"] = {"train": f"0:{aggr_meta.info['total_episodes']}"}
|
||||||
|
|
||||||
# create a new episodes jsonl with updated episode_index using write_episode
|
# create a new episodes jsonl with updated episode_index using write_episode
|
||||||
@@ -168,7 +182,6 @@ class AggregateDatasets(PipelineStep):
|
|||||||
write_info(aggr_meta.info, aggr_meta.root)
|
write_info(aggr_meta.info, aggr_meta.root)
|
||||||
|
|
||||||
self.datasets_task_index_to_aggr_task_index = datasets_task_index_to_aggr_task_index
|
self.datasets_task_index_to_aggr_task_index = datasets_task_index_to_aggr_task_index
|
||||||
self.datasets_ep_idx_to_aggr_ep_idx = datasets_ep_idx_to_aggr_ep_idx
|
|
||||||
self.datasets_aggr_episode_index_shift = datasets_aggr_episode_index_shift
|
self.datasets_aggr_episode_index_shift = datasets_aggr_episode_index_shift
|
||||||
self.datasets_aggr_index_shift = datasets_aggr_index_shift
|
self.datasets_aggr_index_shift = datasets_aggr_index_shift
|
||||||
|
|
||||||
@@ -187,7 +200,7 @@ class AggregateDatasets(PipelineStep):
|
|||||||
with self.track_time("aggregating dataset"):
|
with self.track_time("aggregating dataset"):
|
||||||
logger.info("Copy data")
|
logger.info("Copy data")
|
||||||
for episode_index in range(meta.total_episodes):
|
for episode_index in range(meta.total_episodes):
|
||||||
aggr_episode_index = self.datasets_ep_idx_to_aggr_ep_idx[dataset_index][episode_index]
|
aggr_episode_index = episode_index + aggr_episode_index_shift
|
||||||
data_path = meta.root / meta.get_data_file_path(episode_index)
|
data_path = meta.root / meta.get_data_file_path(episode_index)
|
||||||
aggr_data_path = aggr_meta.root / aggr_meta.get_data_file_path(aggr_episode_index)
|
aggr_data_path = aggr_meta.root / aggr_meta.get_data_file_path(aggr_episode_index)
|
||||||
aggr_data_path.parent.mkdir(parents=True, exist_ok=True)
|
aggr_data_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
@@ -258,6 +271,7 @@ def main(
|
|||||||
aggregate_output_path = output_path / ("_".join([src_path.name for src_path in src_paths]) + "_aggregated_lerobot")
|
aggregate_output_path = output_path / ("_".join([src_path.name for src_path in src_paths]) + "_aggregated_lerobot")
|
||||||
else:
|
else:
|
||||||
aggregate_output_path = output_path / f"{src_paths[0].name}_lerobot"
|
aggregate_output_path = output_path / f"{src_paths[0].name}_lerobot"
|
||||||
|
aggregate_output_path = aggregate_output_path.resolve()
|
||||||
|
|
||||||
if debug:
|
if debug:
|
||||||
executor = "local"
|
executor = "local"
|
||||||
|
|||||||
Reference in New Issue
Block a user