Mirror of https://github.com/Tavish9/any4lerobot.git, synced 2026-05-17 15:09:44 +00:00
🚀 add time tracking for libero2lerobot
+41 -32
@@ -33,7 +33,11 @@ def setup_logger():
 
 
 class SaveLerobotDataset(PipelineStep):
+    name = "Save Temp LerobotDataset"
+    type = "libero2lerobot"
+
     def __init__(self, tasks: list[tuple[Path, Path, str]]):
+        super().__init__()
         self.tasks = tasks
 
     def run(self, data=None, rank: int = 0, world_size: int = 1):
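The first hunk's additions follow the usual datatrove PipelineStep convention: `name` and `type` identify the step in logs and stats output, and `super().__init__()` creates the per-step stats object that `track_time()` later writes into. A minimal sketch of that pattern, assuming the project's `PipelineStep` is datatrove's (the `ExampleStep` class and its payload are illustrative, not part of the repo):

```python
from pathlib import Path

from datatrove.pipeline.base import PipelineStep


class ExampleStep(PipelineStep):      # hypothetical step, for illustration only
    name = "Example Step"             # label used in logs and stats summaries
    type = "libero2lerobot"           # groups related steps of one pipeline

    def __init__(self, paths: list[Path]):
        super().__init__()            # initializes self.stats, needed by track_time()
        self.paths = paths

    def run(self, data=None, rank: int = 0, world_size: int = 1):
        for path in self.paths:
            with self.track_time():   # elapsed wall-clock time is added to this step's stats
                _ = path.exists()     # placeholder for real per-item work
```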
@@ -56,21 +60,21 @@ class SaveLerobotDataset(PipelineStep):
 
             raw_dataset = load_local_episodes(input_h5)
             for episode_index, episode_data in enumerate(raw_dataset):
-                for frame_data in episode_data:
-                    dataset.add_frame(
-                        frame_data,
-                        task=task_instruction,
-                    )
-                dataset.save_episode()
-                logger.info(f"process done for {dataset.repo_id}, episode {episode_index}, len {len(episode_data)}")
+                with self.track_time("saving episode"):
+                    for frame_data in episode_data:
+                        dataset.add_frame(
+                            frame_data,
+                            task=task_instruction,
+                        )
+                    dataset.save_episode()
+                    logger.info(f"process done for {dataset.repo_id}, episode {episode_index}, len {len(episode_data)}")
 
 
 class AggregateDatasets(PipelineStep):
-    def __init__(
-        self,
-        raw_dirs: list[Path],
-        aggregated_dir: Path,
-    ):
+    name = "Aggregate Datasets"
+    type = "libero2lerobot"
+
+    def __init__(self, raw_dirs: list[Path], aggregated_dir: Path):
         super().__init__()
         self.raw_dirs = raw_dirs
         self.aggregated_dir = aggregated_dir
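The second hunk wraps the per-episode save in `with self.track_time("saving episode"):`. Functionally this is a labelled timing context manager that accumulates elapsed time into the step's statistics; a self-contained sketch of that behaviour (an assumption about what `track_time` does, not datatrove's actual implementation):

```python
import time
from collections import defaultdict
from contextlib import contextmanager


class TimedStats:
    """Toy stand-in for a pipeline step's stats object."""

    def __init__(self):
        self.seconds = defaultdict(float)

    @contextmanager
    def track_time(self, label: str = "total"):
        start = time.perf_counter()
        try:
            yield
        finally:
            # everything executed inside the with-block is attributed to `label`
            self.seconds[label] += time.perf_counter() - start


stats = TimedStats()
with stats.track_time("saving episode"):
    time.sleep(0.01)  # stand-in for dataset.add_frame(...) / dataset.save_episode()
print(f"saving episode took {stats.seconds['saving episode']:.3f}s")
```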
@@ -180,32 +184,37 @@ class AggregateDatasets(PipelineStep):
         aggr_index_shift = self.datasets_aggr_index_shift[dataset_index]
         task_index_to_aggr_task_index = self.datasets_task_index_to_aggr_task_index[dataset_index]
 
-        logger.info("Copy data")
-        for episode_index in range(meta.total_episodes):
-            aggr_episode_index = self.datasets_ep_idx_to_aggr_ep_idx[dataset_index][episode_index]
-            data_path = meta.root / meta.get_data_file_path(episode_index)
-            aggr_data_path = aggr_meta.root / aggr_meta.get_data_file_path(aggr_episode_index)
-            aggr_data_path.parent.mkdir(parents=True, exist_ok=True)
-
-            # update index, episode_index and task_index
-            df = pd.read_parquet(data_path)
-            df["index"] += aggr_index_shift
-            df["episode_index"] += aggr_episode_index_shift
-            df["task_index"] = df["task_index"].map(task_index_to_aggr_task_index)
-            df.to_parquet(aggr_data_path)
-
-        logger.info("Copy videos")
-        for episode_index in range(meta.total_episodes):
-            aggr_episode_index = episode_index + aggr_episode_index_shift
-            for vid_key in meta.video_keys:
-                video_path = meta.root / meta.get_video_file_path(episode_index, vid_key)
-                aggr_video_path = aggr_meta.root / aggr_meta.get_video_file_path(aggr_episode_index, vid_key)
-                aggr_video_path.parent.mkdir(parents=True, exist_ok=True)
-                shutil.copy(video_path, aggr_video_path)
+        with self.track_time("aggregating dataset"):
+            logger.info("Copy data")
+            for episode_index in range(meta.total_episodes):
+                aggr_episode_index = self.datasets_ep_idx_to_aggr_ep_idx[dataset_index][episode_index]
+                data_path = meta.root / meta.get_data_file_path(episode_index)
+                aggr_data_path = aggr_meta.root / aggr_meta.get_data_file_path(aggr_episode_index)
+                aggr_data_path.parent.mkdir(parents=True, exist_ok=True)
+
+                # update index, episode_index and task_index
+                df = pd.read_parquet(data_path)
+                df["index"] += aggr_index_shift
+                df["episode_index"] += aggr_episode_index_shift
+                df["task_index"] = df["task_index"].map(task_index_to_aggr_task_index)
+                df.to_parquet(aggr_data_path)
+
+            logger.info("Copy videos")
+            for episode_index in range(meta.total_episodes):
+                aggr_episode_index = episode_index + aggr_episode_index_shift
+                for vid_key in meta.video_keys:
+                    video_path = meta.root / meta.get_video_file_path(episode_index, vid_key)
+                    aggr_video_path = aggr_meta.root / aggr_meta.get_video_file_path(aggr_episode_index, vid_key)
+                    aggr_video_path.parent.mkdir(parents=True, exist_ok=True)
+                    shutil.copy(video_path, aggr_video_path)
 
 
 class DeleteTempData(PipelineStep):
+    name = "Delete Temp Data"
+    type = "libero2lerobot"
+
     def __init__(self, temp_dirs: list[Path]):
+        super().__init__()
         self.temp_dirs = temp_dirs
 
     def run(self, data=None, rank: int = 0, world_size: int = 1):
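With all three steps now reporting timings, the payoff shows up when the pipeline is executed: the time spent saving, aggregating, and cleaning up is collected per step and surfaced with the run's stats. A hedged usage sketch, assuming the steps are driven by datatrove's LocalPipelineExecutor and that each task tuple is (input_h5, output_dir, task_instruction) as suggested by SaveLerobotDataset.run(); the paths, the instruction string, and the import of SaveLerobotDataset from the patched module are illustrative:

```python
from pathlib import Path

from datatrove.executor import LocalPipelineExecutor
# SaveLerobotDataset is the step defined in the diff above;
# its exact module path is not shown on this page.

# Illustrative task list: (input_h5, output_dir, task_instruction) is an assumption
# based on how SaveLerobotDataset.run() uses input_h5 and task_instruction.
tasks = [
    (Path("raw/libero_10/demo_0.hdf5"), Path("/tmp/libero2lerobot/demo_0"), "put the bowl in the sink"),
]

executor = LocalPipelineExecutor(
    pipeline=[SaveLerobotDataset(tasks)],
    tasks=1,                              # number of parallel shards
    logging_dir="logs/libero2lerobot",
)
executor.run()
# The per-step stats saved under logging_dir now include the
# "saving episode" time recorded via track_time().
```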