🚀 add time tracking for libero2lerobot

This commit is contained in:
Tavish
2025-07-01 14:27:11 +08:00
parent 8fdfe7f3cf
commit e486ea3612
+14 -5
View File
@@ -33,7 +33,11 @@ def setup_logger():
class SaveLerobotDataset(PipelineStep): class SaveLerobotDataset(PipelineStep):
name = "Save Temp LerobotDataset"
type = "libero2lerobot"
def __init__(self, tasks: list[tuple[Path, Path, str]]): def __init__(self, tasks: list[tuple[Path, Path, str]]):
super().__init__()
self.tasks = tasks self.tasks = tasks
def run(self, data=None, rank: int = 0, world_size: int = 1): def run(self, data=None, rank: int = 0, world_size: int = 1):
@@ -56,6 +60,7 @@ class SaveLerobotDataset(PipelineStep):
raw_dataset = load_local_episodes(input_h5) raw_dataset = load_local_episodes(input_h5)
for episode_index, episode_data in enumerate(raw_dataset): for episode_index, episode_data in enumerate(raw_dataset):
with self.track_time("saving episode"):
for frame_data in episode_data: for frame_data in episode_data:
dataset.add_frame( dataset.add_frame(
frame_data, frame_data,
@@ -66,11 +71,10 @@ class SaveLerobotDataset(PipelineStep):
class AggregateDatasets(PipelineStep): class AggregateDatasets(PipelineStep):
def __init__( name = "Aggregate Datasets"
self, type = "libero2lerobot"
raw_dirs: list[Path],
aggregated_dir: Path, def __init__(self, raw_dirs: list[Path], aggregated_dir: Path):
):
super().__init__() super().__init__()
self.raw_dirs = raw_dirs self.raw_dirs = raw_dirs
self.aggregated_dir = aggregated_dir self.aggregated_dir = aggregated_dir
@@ -180,6 +184,7 @@ class AggregateDatasets(PipelineStep):
aggr_index_shift = self.datasets_aggr_index_shift[dataset_index] aggr_index_shift = self.datasets_aggr_index_shift[dataset_index]
task_index_to_aggr_task_index = self.datasets_task_index_to_aggr_task_index[dataset_index] task_index_to_aggr_task_index = self.datasets_task_index_to_aggr_task_index[dataset_index]
with self.track_time("aggregating dataset"):
logger.info("Copy data") logger.info("Copy data")
for episode_index in range(meta.total_episodes): for episode_index in range(meta.total_episodes):
aggr_episode_index = self.datasets_ep_idx_to_aggr_ep_idx[dataset_index][episode_index] aggr_episode_index = self.datasets_ep_idx_to_aggr_ep_idx[dataset_index][episode_index]
@@ -205,7 +210,11 @@ class AggregateDatasets(PipelineStep):
class DeleteTempData(PipelineStep): class DeleteTempData(PipelineStep):
name = "Delete Temp Data"
type = "libero2lerobot"
def __init__(self, temp_dirs: list[Path]): def __init__(self, temp_dirs: list[Path]):
super().__init__()
self.temp_dirs = temp_dirs self.temp_dirs = temp_dirs
def run(self, data=None, rank: int = 0, world_size: int = 1): def run(self, data=None, rank: int = 0, world_size: int = 1):