🚀 add time tracking for libero2lerobot

Tavish
2025-07-01 14:27:11 +08:00
parent 8fdfe7f3cf
commit e486ea3612
+41 -32
@@ -33,7 +33,11 @@ def setup_logger():
 class SaveLerobotDataset(PipelineStep):
+    name = "Save Temp LerobotDataset"
+    type = "libero2lerobot"
+
     def __init__(self, tasks: list[tuple[Path, Path, str]]):
+        super().__init__()
         self.tasks = tasks
 
     def run(self, data=None, rank: int = 0, world_size: int = 1):
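Note: the new `name`/`type` attributes and the `track_time()` calls in the hunks below are part of the `PipelineStep` base class this script builds on (presumably a datatrove-style step). A minimal sketch of the interface the diff relies on — the `time_stats` dict and the exact signatures are assumptions made for illustration, not the real implementation:

```python
# Hypothetical sketch of a PipelineStep base class with per-label timing.
# Not the actual library code; shown only to make the diff easier to read.
import time
from collections import defaultdict
from contextlib import contextmanager


class PipelineStep:
    name = "Unnamed step"   # human-readable label used in logs/stats
    type = "generic"        # step category

    def __init__(self):
        # accumulated wall-clock seconds per label, filled by track_time()
        self.time_stats = defaultdict(float)

    @contextmanager
    def track_time(self, label: str = "total"):
        start = time.perf_counter()
        try:
            yield
        finally:
            self.time_stats[label] += time.perf_counter() - start

    def run(self, data=None, rank: int = 0, world_size: int = 1):
        raise NotImplementedError
```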
@@ -56,21 +60,21 @@ class SaveLerobotDataset(PipelineStep):
             raw_dataset = load_local_episodes(input_h5)
 
             for episode_index, episode_data in enumerate(raw_dataset):
-                for frame_data in episode_data:
-                    dataset.add_frame(
-                        frame_data,
-                        task=task_instruction,
-                    )
-                dataset.save_episode()
+                with self.track_time("saving episode"):
+                    for frame_data in episode_data:
+                        dataset.add_frame(
+                            frame_data,
+                            task=task_instruction,
+                        )
+                    dataset.save_episode()
                 logger.info(f"process done for {dataset.repo_id}, episode {episode_index}, len {len(episode_data)}")
 
 
 class AggregateDatasets(PipelineStep):
-    def __init__(
-        self,
-        raw_dirs: list[Path],
-        aggregated_dir: Path,
-    ):
+    name = "Aggregate Datasets"
+    type = "libero2lerobot"
+
+    def __init__(self, raw_dirs: list[Path], aggregated_dir: Path):
         super().__init__()
         self.raw_dirs = raw_dirs
         self.aggregated_dir = aggregated_dir
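Worth noting is where the `with` block sits: inside the episode loop, so each iteration adds only its own saving time to the same label, and the logging done after the block is excluded. A standalone toy sketch of that accumulation — the `track_time` helper here is a stand-in, not the project's API:

```python
# Toy illustration of per-iteration time accumulation under one label.
import time
from collections import defaultdict
from contextlib import contextmanager

time_stats = defaultdict(float)


@contextmanager
def track_time(label):
    start = time.perf_counter()
    try:
        yield
    finally:
        time_stats[label] += time.perf_counter() - start


for episode in range(3):
    with track_time("saving episode"):
        time.sleep(0.01)  # stand-in for add_frame()/save_episode() work
    time.sleep(0.01)      # e.g. logging -- intentionally not counted

print(round(time_stats["saving episode"], 2))  # ~0.03, not ~0.06
```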
@@ -180,32 +184,37 @@ class AggregateDatasets(PipelineStep):
         aggr_index_shift = self.datasets_aggr_index_shift[dataset_index]
         task_index_to_aggr_task_index = self.datasets_task_index_to_aggr_task_index[dataset_index]
 
-        logger.info("Copy data")
-        for episode_index in range(meta.total_episodes):
-            aggr_episode_index = self.datasets_ep_idx_to_aggr_ep_idx[dataset_index][episode_index]
-            data_path = meta.root / meta.get_data_file_path(episode_index)
-            aggr_data_path = aggr_meta.root / aggr_meta.get_data_file_path(aggr_episode_index)
-            aggr_data_path.parent.mkdir(parents=True, exist_ok=True)
+        with self.track_time("aggregating dataset"):
+            logger.info("Copy data")
+            for episode_index in range(meta.total_episodes):
+                aggr_episode_index = self.datasets_ep_idx_to_aggr_ep_idx[dataset_index][episode_index]
+                data_path = meta.root / meta.get_data_file_path(episode_index)
+                aggr_data_path = aggr_meta.root / aggr_meta.get_data_file_path(aggr_episode_index)
+                aggr_data_path.parent.mkdir(parents=True, exist_ok=True)
 
-            # update index, episode_index and task_index
-            df = pd.read_parquet(data_path)
-            df["index"] += aggr_index_shift
-            df["episode_index"] += aggr_episode_index_shift
-            df["task_index"] = df["task_index"].map(task_index_to_aggr_task_index)
-            df.to_parquet(aggr_data_path)
+                # update index, episode_index and task_index
+                df = pd.read_parquet(data_path)
+                df["index"] += aggr_index_shift
+                df["episode_index"] += aggr_episode_index_shift
+                df["task_index"] = df["task_index"].map(task_index_to_aggr_task_index)
+                df.to_parquet(aggr_data_path)
 
-        logger.info("Copy videos")
-        for episode_index in range(meta.total_episodes):
-            aggr_episode_index = episode_index + aggr_episode_index_shift
-            for vid_key in meta.video_keys:
-                video_path = meta.root / meta.get_video_file_path(episode_index, vid_key)
-                aggr_video_path = aggr_meta.root / aggr_meta.get_video_file_path(aggr_episode_index, vid_key)
-                aggr_video_path.parent.mkdir(parents=True, exist_ok=True)
-                shutil.copy(video_path, aggr_video_path)
+            logger.info("Copy videos")
+            for episode_index in range(meta.total_episodes):
+                aggr_episode_index = episode_index + aggr_episode_index_shift
+                for vid_key in meta.video_keys:
+                    video_path = meta.root / meta.get_video_file_path(episode_index, vid_key)
+                    aggr_video_path = aggr_meta.root / aggr_meta.get_video_file_path(aggr_episode_index, vid_key)
+                    aggr_video_path.parent.mkdir(parents=True, exist_ok=True)
+                    shutil.copy(video_path, aggr_video_path)
 
 
 class DeleteTempData(PipelineStep):
+    name = "Delete Temp Data"
+    type = "libero2lerobot"
+
     def __init__(self, temp_dirs: list[Path]):
+        super().__init__()
         self.temp_dirs = temp_dirs
 
     def run(self, data=None, rank: int = 0, world_size: int = 1):
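For reference, the three-column remapping the aggregation hunk applies to each episode's parquet file can be reproduced in isolation. The column names come from the diff; the shift values and the task mapping below are invented for illustration:

```python
# Tiny example of the per-episode index remapping done in AggregateDatasets.
import pandas as pd

aggr_index_shift = 1000          # frames already present in the aggregate
aggr_episode_index_shift = 10    # episodes already present in the aggregate
task_index_to_aggr_task_index = {0: 4, 1: 7}

df = pd.DataFrame(
    {"index": [0, 1, 2], "episode_index": [0, 0, 0], "task_index": [0, 0, 1]}
)
df["index"] += aggr_index_shift
df["episode_index"] += aggr_episode_index_shift
df["task_index"] = df["task_index"].map(task_index_to_aggr_task_index)

print(df)
#    index  episode_index  task_index
# 0   1000             10           4
# 1   1001             10           4
# 2   1002             10           7
```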