save action_config in each episode

This commit is contained in:
Tavish
2025-04-18 13:48:29 +08:00
parent 2b0f699560
commit 6cd646e91c
2 changed files with 118 additions and 42 deletions
+7 -10
View File
@@ -6,14 +6,11 @@ import numpy as np
from PIL import Image
def get_task_instruction(task_json_path: str) -> dict:
"""Get task language instruction"""
def get_task_info(task_json_path: str) -> dict:
with open(task_json_path, "r") as f:
task_info = json.load(f)
task_name = task_info[0]["task_name"]
task_init_scene = task_info[0]["init_scene_text"]
task_instruction = f"{task_name}.{task_init_scene}"
return task_instruction
task_info: list = json.load(f)
task_info.sort(key=lambda episode: episode["episode_id"])
return task_info
def load_depths(root_dir: str, camera_name: str):
@@ -23,7 +20,7 @@ def load_depths(root_dir: str, camera_name: str):
def load_local_dataset(
episode_id: int, src_path: str, task_id: int, task_name: str, save_depth: bool, AgiBotWorld_CONFIG: dict
episode_id: int, src_path: str, task_id: int, task_instruction: str, save_depth: bool, AgiBotWorld_CONFIG: dict
) -> tuple[list, dict]:
"""Load local dataset and return a dict with observations and actions"""
ob_dir = Path(src_path) / f"observations/{task_id}/{episode_id}"
@@ -79,7 +76,7 @@ def load_local_dataset(
)
for key, value in action.items()
},
"task": task_name,
"task": task_instruction,
}
for i in range(num_frames)
]
@@ -91,4 +88,4 @@ def load_local_dataset(
for key in AgiBotWorld_CONFIG["images"]
if "depth" not in key
}
return frames, videos
return episode_id, frames, videos