🚀 add time tracking for libero2lerobot

This commit is contained in:
Tavish
2025-07-01 14:27:11 +08:00
parent 8fdfe7f3cf
commit e486ea3612
+14 -5
View File
@@ -33,7 +33,11 @@ def setup_logger():
class SaveLerobotDataset(PipelineStep):
name = "Save Temp LerobotDataset"
type = "libero2lerobot"
def __init__(self, tasks: list[tuple[Path, Path, str]]):
super().__init__()
self.tasks = tasks
def run(self, data=None, rank: int = 0, world_size: int = 1):
@@ -56,6 +60,7 @@ class SaveLerobotDataset(PipelineStep):
raw_dataset = load_local_episodes(input_h5)
for episode_index, episode_data in enumerate(raw_dataset):
with self.track_time("saving episode"):
for frame_data in episode_data:
dataset.add_frame(
frame_data,
@@ -66,11 +71,10 @@ class SaveLerobotDataset(PipelineStep):
class AggregateDatasets(PipelineStep):
def __init__(
self,
raw_dirs: list[Path],
aggregated_dir: Path,
):
name = "Aggregate Datasets"
type = "libero2lerobot"
def __init__(self, raw_dirs: list[Path], aggregated_dir: Path):
super().__init__()
self.raw_dirs = raw_dirs
self.aggregated_dir = aggregated_dir
@@ -180,6 +184,7 @@ class AggregateDatasets(PipelineStep):
aggr_index_shift = self.datasets_aggr_index_shift[dataset_index]
task_index_to_aggr_task_index = self.datasets_task_index_to_aggr_task_index[dataset_index]
with self.track_time("aggregating dataset"):
logger.info("Copy data")
for episode_index in range(meta.total_episodes):
aggr_episode_index = self.datasets_ep_idx_to_aggr_ep_idx[dataset_index][episode_index]
@@ -205,7 +210,11 @@ class AggregateDatasets(PipelineStep):
class DeleteTempData(PipelineStep):
name = "Delete Temp Data"
type = "libero2lerobot"
def __init__(self, temp_dirs: list[Path]):
super().__init__()
self.temp_dirs = temp_dirs
def run(self, data=None, rank: int = 0, world_size: int = 1):