mirror of
https://github.com/Tavish9/any4lerobot.git
synced 2026-05-22 09:29:44 +00:00
🚀 add time tracking for libero2lerobot
This commit is contained in:
@@ -33,7 +33,11 @@ def setup_logger():
|
||||
|
||||
|
||||
class SaveLerobotDataset(PipelineStep):
|
||||
name = "Save Temp LerobotDataset"
|
||||
type = "libero2lerobot"
|
||||
|
||||
def __init__(self, tasks: list[tuple[Path, Path, str]]):
|
||||
super().__init__()
|
||||
self.tasks = tasks
|
||||
|
||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||
@@ -56,6 +60,7 @@ class SaveLerobotDataset(PipelineStep):
|
||||
|
||||
raw_dataset = load_local_episodes(input_h5)
|
||||
for episode_index, episode_data in enumerate(raw_dataset):
|
||||
with self.track_time("saving episode"):
|
||||
for frame_data in episode_data:
|
||||
dataset.add_frame(
|
||||
frame_data,
|
||||
@@ -66,11 +71,10 @@ class SaveLerobotDataset(PipelineStep):
|
||||
|
||||
|
||||
class AggregateDatasets(PipelineStep):
|
||||
def __init__(
|
||||
self,
|
||||
raw_dirs: list[Path],
|
||||
aggregated_dir: Path,
|
||||
):
|
||||
name = "Aggregate Datasets"
|
||||
type = "libero2lerobot"
|
||||
|
||||
def __init__(self, raw_dirs: list[Path], aggregated_dir: Path):
|
||||
super().__init__()
|
||||
self.raw_dirs = raw_dirs
|
||||
self.aggregated_dir = aggregated_dir
|
||||
@@ -180,6 +184,7 @@ class AggregateDatasets(PipelineStep):
|
||||
aggr_index_shift = self.datasets_aggr_index_shift[dataset_index]
|
||||
task_index_to_aggr_task_index = self.datasets_task_index_to_aggr_task_index[dataset_index]
|
||||
|
||||
with self.track_time("aggregating dataset"):
|
||||
logger.info("Copy data")
|
||||
for episode_index in range(meta.total_episodes):
|
||||
aggr_episode_index = self.datasets_ep_idx_to_aggr_ep_idx[dataset_index][episode_index]
|
||||
@@ -205,7 +210,11 @@ class AggregateDatasets(PipelineStep):
|
||||
|
||||
|
||||
class DeleteTempData(PipelineStep):
|
||||
name = "Delete Temp Data"
|
||||
type = "libero2lerobot"
|
||||
|
||||
def __init__(self, temp_dirs: list[Path]):
|
||||
super().__init__()
|
||||
self.temp_dirs = temp_dirs
|
||||
|
||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||
|
||||
Reference in New Issue
Block a user